From e3b2be6b9c18ef1775670d196b2a177d4cebfe56 Mon Sep 17 00:00:00 2001 From: Thia Su Mian Date: Wed, 3 Jul 2024 10:59:25 +0800 Subject: [PATCH 01/10] implement risc zero accelerator on p256 --- p256/Cargo.toml | 1 + p256/src/arithmetic.rs | 54 +++- p256/src/arithmetic/field.rs | 306 +++++++----------- p256/src/arithmetic/field/field32.rs | 280 ++++++++++++++++ p256/src/arithmetic/field/field64.rs | 213 ++++++++++++ p256/src/arithmetic/field/field_risc0.rs | 391 +++++++++++++++++++++++ p256/src/arithmetic/scalar.rs | 86 +++-- p256/src/arithmetic/scalar/scalar32.rs | 387 ++++++++++++++-------- p256/src/arithmetic/scalar/scalar64.rs | 177 +++++----- p256/src/arithmetic/util.rs | 72 ----- 10 files changed, 1459 insertions(+), 508 deletions(-) create mode 100644 p256/src/arithmetic/field/field32.rs create mode 100644 p256/src/arithmetic/field/field64.rs create mode 100644 p256/src/arithmetic/field/field_risc0.rs delete mode 100644 p256/src/arithmetic/util.rs diff --git a/p256/Cargo.toml b/p256/Cargo.toml index 59b98444..97f394c0 100644 --- a/p256/Cargo.toml +++ b/p256/Cargo.toml @@ -17,6 +17,7 @@ edition = "2021" rust-version = "1.65" [dependencies] +cfg-if = "1.0" elliptic-curve = { version = "0.13.8", default-features = false, features = ["hazmat", "sec1"] } # optional dependencies diff --git a/p256/src/arithmetic.rs b/p256/src/arithmetic.rs index 7cdf8b1d..4120e93d 100644 --- a/p256/src/arithmetic.rs +++ b/p256/src/arithmetic.rs @@ -8,11 +8,10 @@ pub(crate) mod field; #[cfg(feature = "hash2curve")] mod hash2curve; pub(crate) mod scalar; -pub(crate) mod util; use self::{field::FieldElement, scalar::Scalar}; use crate::NistP256; -use elliptic_curve::{CurveArithmetic, PrimeCurveArithmetic}; +use elliptic_curve::{bigint::U256, CurveArithmetic, PrimeCurveArithmetic}; use primeorder::{point_arithmetic, PrimeCurveParams}; /// Elliptic curve point in affine coordinates. 
@@ -39,10 +38,13 @@ impl PrimeCurveParams for NistP256 { type PointArithmetic = point_arithmetic::EquationAIsMinusThree; /// a = -3 - const EQUATION_A: FieldElement = FieldElement::from_u64(3).neg(); + const EQUATION_A: FieldElement = FieldElement(U256::from_be_hex( + "FFFFFFFC00000004000000000000000000000003FFFFFFFFFFFFFFFFFFFFFFFC", + )); - const EQUATION_B: FieldElement = - FieldElement::from_hex("5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b"); + const EQUATION_B: FieldElement = FieldElement(U256::from_be_hex( + "DC30061D04874834E5A220ABF7212ED6ACF005CD78843090D89CDF6229C4BDDF", + )); /// Base point of P-256. /// @@ -53,7 +55,45 @@ impl PrimeCurveParams for NistP256 { /// Gᵧ = 4fe342e2 fe1a7f9b 8ee7eb4a 7c0f9e16 2bce3357 6b315ece cbb64068 37bf51f5 /// ``` const GENERATOR: (FieldElement, FieldElement) = ( - FieldElement::from_hex("6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296"), - FieldElement::from_hex("4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5"), + FieldElement(U256::from_be_hex( + "18905F76A53755C679FB732B7762251075BA95FC5FEDB60179E730D418A9143C", + )), + FieldElement(U256::from_be_hex( + "8571FF1825885D85D2E88688DD21F3258B4AB8E4BA19E45CDDF25357CE95560A", + )), ); } + +#[cfg(test)] +mod tests { + use super::FieldElement; + use crate::NistP256; + use primeorder::PrimeCurveParams; + + #[test] + fn equation_a_constant() { + let equation_a = FieldElement::from_u64(3).neg(); + assert_eq!(equation_a, NistP256::EQUATION_A); + } + + #[test] + fn equation_b_constant() { + let equation_b = FieldElement::from_hex( + "5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b", + ); + assert_eq!(equation_b, NistP256::EQUATION_B); + } + + #[test] + fn generator_constant() { + let generator: (FieldElement, FieldElement) = ( + FieldElement::from_hex( + "6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296", + ), + FieldElement::from_hex( + 
"4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5", + ), + ); + assert_eq!(generator, NistP256::GENERATOR); + } +} diff --git a/p256/src/arithmetic/field.rs b/p256/src/arithmetic/field.rs index eb54746d..256b66ab 100644 --- a/p256/src/arithmetic/field.rs +++ b/p256/src/arithmetic/field.rs @@ -2,17 +2,28 @@ #![allow(clippy::assign_op_pattern, clippy::op_ref)] -use crate::{ - arithmetic::util::{adc, mac, sbb, u256_to_u64x4, u64x4_to_u256}, - FieldBytes, -}; +#[cfg_attr( + all(target_os = "zkvm", target_arch = "riscv32"), + path = "field/field_risc0.rs" +)] +#[cfg_attr( + all( + not(all(target_os = "zkvm", target_arch = "riscv32")), + target_pointer_width = "32" + ), + path = "field/field32.rs" +)] +#[cfg_attr(target_pointer_width = "64", path = "field/field64.rs")] +mod field_impl; + +use crate::FieldBytes; use core::{ iter::{Product, Sum}, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; use elliptic_curve::ops::Invert; use elliptic_curve::{ - bigint::{ArrayEncoding, U256}, + bigint::{ArrayEncoding, U256, U512}, ff::{Field, PrimeField}, rand_core::RngCore, subtle::{Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeLess, CtOption}, @@ -73,7 +84,7 @@ impl FieldElement { } /// Convert a `u64` into a [`FieldElement`]. - pub const fn from_u64(w: u64) -> Self { + pub fn from_u64(w: u64) -> Self { Self::from_uint_unchecked(U256::from_u64(w)) } @@ -82,7 +93,8 @@ impl FieldElement { /// Does *not* perform a check that the field element does not overflow the order. /// /// This method is primarily intended for defining internal constants. - pub(crate) const fn from_hex(hex: &str) -> Self { + #[allow(dead_code)] + pub(crate) fn from_hex(hex: &str) -> Self { Self::from_uint_unchecked(U256::from_be_hex(hex)) } @@ -91,7 +103,7 @@ impl FieldElement { /// Does *not* perform a check that the field element does not overflow the order. /// /// Used incorrectly this can lead to invalid results! 
- pub(crate) const fn from_uint_unchecked(w: U256) -> Self { + pub(crate) fn from_uint_unchecked(w: U256) -> Self { Self(w).to_montgomery() } @@ -116,33 +128,27 @@ impl FieldElement { /// Returns self + rhs mod p pub const fn add(&self, rhs: &Self) -> Self { - let a = u256_to_u64x4(self.0); - let b = u256_to_u64x4(rhs.0); - - // Bit 256 of p is set, so addition can result in five words. - let (w0, carry) = adc(a[0], b[0], 0); - let (w1, carry) = adc(a[1], b[1], carry); - let (w2, carry) = adc(a[2], b[2], carry); - let (w3, w4) = adc(a[3], b[3], carry); - - // Attempt to subtract the modulus, to ensure the result is in the field. - let modulus = u256_to_u64x4(MODULUS.0); - let (result, _) = Self::sub_inner( - w0, w1, w2, w3, w4, modulus[0], modulus[1], modulus[2], modulus[3], 0, - ); - result + Self(field_impl::add(self.0, rhs.0)) + } + + /// Multiplies by a single-limb integer. + /// Multiplies the magnitude by the same value. + pub fn mul_single(&self, rhs: u32) -> Self { + Self(field_impl::mul_single(self.0, rhs)) } /// Returns 2*self. - pub const fn double(&self) -> Self { - self.add(self) + pub fn double(&self) -> Self { + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + self.mul_single(2) + } else { + self.add(self) + } } /// Returns self - rhs mod p pub const fn sub(&self, rhs: &Self) -> Self { - let a = u256_to_u64x4(self.0); - let b = u256_to_u64x4(rhs.0); - Self::sub_inner(a[0], a[1], a[2], a[3], 0, b[0], b[1], b[2], b[3], 0).0 + Self(field_impl::sub(self.0, rhs.0)) } /// Negate element. 
@@ -150,182 +156,31 @@ impl FieldElement { Self::sub(&Self::ZERO, self) } - fn from_bytes_wide(bytes: [u8; 64]) -> Self { - #[allow(clippy::unwrap_used)] - FieldElement::montgomery_reduce( - u64::from_be_bytes(bytes[0..8].try_into().unwrap()), - u64::from_be_bytes(bytes[8..16].try_into().unwrap()), - u64::from_be_bytes(bytes[16..24].try_into().unwrap()), - u64::from_be_bytes(bytes[24..32].try_into().unwrap()), - u64::from_be_bytes(bytes[32..40].try_into().unwrap()), - u64::from_be_bytes(bytes[40..48].try_into().unwrap()), - u64::from_be_bytes(bytes[48..56].try_into().unwrap()), - u64::from_be_bytes(bytes[56..64].try_into().unwrap()), - ) - } - - #[inline] - #[allow(clippy::too_many_arguments)] - const fn sub_inner( - l0: u64, - l1: u64, - l2: u64, - l3: u64, - l4: u64, - r0: u64, - r1: u64, - r2: u64, - r3: u64, - r4: u64, - ) -> (Self, u64) { - let (w0, borrow) = sbb(l0, r0, 0); - let (w1, borrow) = sbb(l1, r1, borrow); - let (w2, borrow) = sbb(l2, r2, borrow); - let (w3, borrow) = sbb(l3, r3, borrow); - let (_, borrow) = sbb(l4, r4, borrow); - - // If underflow occurred on the final limb, borrow = 0xfff...fff, otherwise - // borrow = 0x000...000. Thus, we use it as a mask to conditionally add the - // modulus. - let modulus = u256_to_u64x4(MODULUS.0); - let (w0, carry) = adc(w0, modulus[0] & borrow, 0); - let (w1, carry) = adc(w1, modulus[1] & borrow, carry); - let (w2, carry) = adc(w2, modulus[2] & borrow, carry); - let (w3, _) = adc(w3, modulus[3] & borrow, carry); - - (Self(u64x4_to_u256([w0, w1, w2, w3])), borrow) - } - - /// Montgomery Reduction - /// - /// The general algorithm is: - /// ```text - /// A <- input (2n b-limbs) - /// for i in 0..n { - /// k <- A[i] p' mod b - /// A <- A + k p b^i - /// } - /// A <- A / b^n - /// if A >= p { - /// A <- A - p - /// } - /// ``` - /// - /// For secp256r1, we have the following simplifications: - /// - /// - `p'` is 1, so our multiplicand is simply the first limb of the intermediate A. 
- /// - /// - The first limb of p is 2^64 - 1; multiplications by this limb can be simplified - /// to a shift and subtraction: - /// ```text - /// a_i * (2^64 - 1) = a_i * 2^64 - a_i = (a_i << 64) - a_i - /// ``` - /// However, because `p' = 1`, the first limb of p is multiplied by limb i of the - /// intermediate A and then immediately added to that same limb, so we simply - /// initialize the carry to limb i of the intermediate. - /// - /// - The third limb of p is zero, so we can ignore any multiplications by it and just - /// add the carry. - /// - /// References: - /// - Handbook of Applied Cryptography, Chapter 14 - /// Algorithm 14.32 - /// http://cacr.uwaterloo.ca/hac/about/chap14.pdf - /// - /// - Efficient and Secure Elliptic Curve Cryptography Implementation of Curve P-256 - /// Algorithm 7) Montgomery Word-by-Word Reduction - /// https://csrc.nist.gov/csrc/media/events/workshop-on-elliptic-curve-cryptography-standards/documents/papers/session6-adalier-mehmet.pdf - #[inline] - #[allow(clippy::too_many_arguments)] - const fn montgomery_reduce( - r0: u64, - r1: u64, - r2: u64, - r3: u64, - r4: u64, - r5: u64, - r6: u64, - r7: u64, - ) -> Self { - let modulus = u256_to_u64x4(MODULUS.0); - - let (r1, carry) = mac(r1, r0, modulus[1], r0); - let (r2, carry) = adc(r2, 0, carry); - let (r3, carry) = mac(r3, r0, modulus[3], carry); - let (r4, carry2) = adc(r4, 0, carry); - - let (r2, carry) = mac(r2, r1, modulus[1], r1); - let (r3, carry) = adc(r3, 0, carry); - let (r4, carry) = mac(r4, r1, modulus[3], carry); - let (r5, carry2) = adc(r5, carry2, carry); - - let (r3, carry) = mac(r3, r2, modulus[1], r2); - let (r4, carry) = adc(r4, 0, carry); - let (r5, carry) = mac(r5, r2, modulus[3], carry); - let (r6, carry2) = adc(r6, carry2, carry); - - let (r4, carry) = mac(r4, r3, modulus[1], r3); - let (r5, carry) = adc(r5, 0, carry); - let (r6, carry) = mac(r6, r3, modulus[3], carry); - let (r7, r8) = adc(r7, carry2, carry); - - // Result may be within MODULUS of the 
correct value - let (result, _) = Self::sub_inner( - r4, r5, r6, r7, r8, modulus[0], modulus[1], modulus[2], modulus[3], 0, - ); - result - } - /// Translate a field element out of the Montgomery domain. #[inline] pub(crate) const fn to_canonical(self) -> Self { - let w = u256_to_u64x4(self.0); - FieldElement::montgomery_reduce(w[0], w[1], w[2], w[3], 0, 0, 0, 0) + Self(field_impl::to_canonical(self.0)) } /// Translate a field element into the Montgomery domain. #[inline] - pub(crate) const fn to_montgomery(self) -> Self { + pub(crate) fn to_montgomery(self) -> Self { Self::multiply(&self, &R2) } /// Returns self * rhs mod p - pub const fn multiply(&self, rhs: &Self) -> Self { - // Schoolbook multiplication. - let a = u256_to_u64x4(self.0); - let b = u256_to_u64x4(rhs.0); - - let (w0, carry) = mac(0, a[0], b[0], 0); - let (w1, carry) = mac(0, a[0], b[1], carry); - let (w2, carry) = mac(0, a[0], b[2], carry); - let (w3, w4) = mac(0, a[0], b[3], carry); - - let (w1, carry) = mac(w1, a[1], b[0], 0); - let (w2, carry) = mac(w2, a[1], b[1], carry); - let (w3, carry) = mac(w3, a[1], b[2], carry); - let (w4, w5) = mac(w4, a[1], b[3], carry); - - let (w2, carry) = mac(w2, a[2], b[0], 0); - let (w3, carry) = mac(w3, a[2], b[1], carry); - let (w4, carry) = mac(w4, a[2], b[2], carry); - let (w5, w6) = mac(w5, a[2], b[3], carry); - - let (w3, carry) = mac(w3, a[3], b[0], 0); - let (w4, carry) = mac(w4, a[3], b[1], carry); - let (w5, carry) = mac(w5, a[3], b[2], carry); - let (w6, w7) = mac(w6, a[3], b[3], carry); - - FieldElement::montgomery_reduce(w0, w1, w2, w3, w4, w5, w6, w7) + pub fn multiply(&self, rhs: &Self) -> Self { + Self(field_impl::mul(self.0, rhs.0)) } /// Returns self * self mod p - pub const fn square(&self) -> Self { + pub fn square(&self) -> Self { // Schoolbook multiplication. 
self.multiply(self) } /// Returns self^(2^n) mod p - const fn sqn(&self, n: usize) -> Self { + fn sqn(&self, n: usize) -> Self { let mut x = *self; let mut i = 0; while i < n { @@ -361,7 +216,7 @@ impl FieldElement { /// Returns the multiplicative inverse of self. /// /// Does not check that self is non-zero. - const fn invert_unchecked(&self) -> Self { + pub fn invert_unchecked(&self) -> Self { // We need to find b such that b * a ≡ 1 mod p. As we are in a prime // field, we can apply Fermat's Little Theorem: // @@ -420,7 +275,8 @@ impl Field for FieldElement { // negligible bias from the uniform distribution. let mut buf = [0; 64]; rng.fill_bytes(&mut buf); - FieldElement::from_bytes_wide(buf) + let buf = U512::from_be_slice(&buf); + Self(field_impl::from_bytes_wide(buf)) } #[must_use] @@ -452,13 +308,22 @@ impl PrimeField for FieldElement { const MODULUS: &'static str = MODULUS_HEX; const NUM_BITS: u32 = 256; const CAPACITY: u32 = 255; - const TWO_INV: Self = Self::from_u64(2).invert_unchecked(); - const MULTIPLICATIVE_GENERATOR: Self = Self::from_u64(6); + const TWO_INV: Self = Self(U256::from_be_hex( + "8000000000000000000000000000000000000000000000000000000000000000", + )); + const MULTIPLICATIVE_GENERATOR: Self = Self(U256::from_be_hex( + "00000005FFFFFFF9FFFFFFFFFFFFFFFFFFFFFFFA000000000000000000000006", + )); const S: u32 = 1; - const ROOT_OF_UNITY: Self = - Self::from_hex("ffffffff00000001000000000000000000000000fffffffffffffffffffffffe"); - const ROOT_OF_UNITY_INV: Self = Self::ROOT_OF_UNITY.invert_unchecked(); - const DELTA: Self = Self::from_u64(36); + const ROOT_OF_UNITY: Self = Self(U256::from_be_hex( + "FFFFFFFE00000002000000000000000000000001FFFFFFFFFFFFFFFFFFFFFFFE", + )); + const ROOT_OF_UNITY_INV: Self = Self(U256::from_be_hex( + "FFFFFFFE00000002000000000000000000000001FFFFFFFFFFFFFFFFFFFFFFFE", + )); + const DELTA: Self = Self(U256::from_be_hex( + "00000023FFFFFFDBFFFFFFFFFFFFFFFFFFFFFFDC000000000000000000000024", + )); fn from_repr(bytes: 
FieldBytes) -> CtOption { Self::from_bytes(bytes) @@ -642,7 +507,7 @@ impl Neg for &FieldElement { impl Sum for FieldElement { fn sum>(iter: I) -> Self { - iter.reduce(core::ops::Add::add).unwrap_or(Self::ZERO) + iter.reduce(Add::add).unwrap_or(Self::ZERO) } } @@ -654,7 +519,7 @@ impl<'a> Sum<&'a FieldElement> for FieldElement { impl Product for FieldElement { fn product>(iter: I) -> Self { - iter.reduce(core::ops::Mul::mul).unwrap_or(Self::ONE) + iter.reduce(Mul::mul).unwrap_or(Self::ONE) } } @@ -666,9 +531,14 @@ impl<'a> Product<&'a FieldElement> for FieldElement { #[cfg(test)] mod tests { - use super::{u64x4_to_u256, FieldElement}; + use super::FieldElement; use crate::{test_vectors::field::DBL_TEST_VECTORS, FieldBytes}; use core::ops::Mul; + use elliptic_curve::ff::PrimeField; + + #[cfg(target_pointer_width = "64")] + use crate::U256; + #[cfg(target_pointer_width = "64")] use proptest::{num::u64::ANY, prelude::*}; #[test] @@ -679,6 +549,45 @@ mod tests { assert_eq!(one.add(&zero), one); } + #[test] + fn root_of_unity_constant() { + let root_of_unity = FieldElement::from_hex( + "ffffffff00000001000000000000000000000000fffffffffffffffffffffffe", + ); + let root_of_unity_inv = root_of_unity.invert_unchecked(); + assert_eq!(root_of_unity, FieldElement::ROOT_OF_UNITY); + assert_eq!(root_of_unity_inv, FieldElement::ROOT_OF_UNITY_INV); + assert_eq!( + (FieldElement::ROOT_OF_UNITY * FieldElement::ROOT_OF_UNITY_INV), + FieldElement::ONE + ) + } + + #[test] + fn two_inv_constant() { + let number = FieldElement::from_u64(2).invert_unchecked(); + assert_eq!(number, FieldElement::TWO_INV); + assert_eq!( + (FieldElement::from(2u64) * FieldElement::TWO_INV), + FieldElement::ONE + ); + } + + #[test] + fn multiplicative_generator_constant() { + let multiplicative_generator = FieldElement::from_u64(6); + assert_eq!( + multiplicative_generator, + FieldElement::MULTIPLICATIVE_GENERATOR + ); + } + + #[test] + fn delta_constant() { + let delta = FieldElement::from_u64(36); + 
assert_eq!(delta, FieldElement::DELTA); + } + #[test] fn one_is_multiplicative_identity() { let one = FieldElement::ONE; @@ -784,6 +693,7 @@ mod tests { assert_eq!(four.sqrt().unwrap(), two); } + #[cfg(target_pointer_width = "64")] proptest! { /// This checks behaviour well within the field ranges, because it doesn't set the /// highest limb. @@ -796,8 +706,8 @@ mod tests { b1 in ANY, b2 in ANY, ) { - let a = FieldElement(u64x4_to_u256([a0, a1, a2, 0])); - let b = FieldElement(u64x4_to_u256([b0, b1, b2, 0])); + let a = FieldElement(U256::from_words([a0, a1, a2, 0])); + let b = FieldElement(U256::from_words([b0, b1, b2, 0])); assert_eq!(a.add(&b).sub(&a), b); } } diff --git a/p256/src/arithmetic/field/field32.rs b/p256/src/arithmetic/field/field32.rs new file mode 100644 index 00000000..da9bd183 --- /dev/null +++ b/p256/src/arithmetic/field/field32.rs @@ -0,0 +1,280 @@ +//! 32-bit secp256r1 field element algorithms. + +use super::MODULUS; +use elliptic_curve::bigint::{Limb, U256, U512}; + +pub(super) const fn add(a: U256, b: U256) -> U256 { + let a = a.as_limbs(); + let b = b.as_limbs(); + + // Bit 256 of p is set, so addition can result in nine words. + // let (w0, carry) = adc(a[0], b[0], 0); + let (w0, carry) = a[0].adc(b[0], Limb::ZERO); + let (w1, carry) = a[1].adc(b[1], carry); + let (w2, carry) = a[2].adc(b[2], carry); + let (w3, carry) = a[3].adc(b[3], carry); + let (w4, carry) = a[4].adc(b[4], carry); + let (w5, carry) = a[5].adc(b[5], carry); + let (w6, carry) = a[6].adc(b[6], carry); + let (w7, w8) = a[7].adc(b[7], carry); + // Attempt to subtract the modulus, to ensure the result is in the field. 
+ let modulus = MODULUS.0.as_limbs(); + + let (result, _) = sub_inner( + [w0, w1, w2, w3, w4, w5, w6, w7, w8], + [ + modulus[0], + modulus[1], + modulus[2], + modulus[3], + modulus[4], + modulus[5], + modulus[6], + modulus[7], + Limb::ZERO, + ], + ); + U256::new([ + result[0], result[1], result[2], result[3], result[4], result[5], result[6], result[7], + ]) +} + +/// Returns `a * rhs mod p`, computed by adding `a` to an accumulator `rhs` times. +/// Runs in time linear in `rhs`, so it is only suitable for small multipliers. +pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { + let mut result = U256::ZERO; + for _i in 0..rhs { + result = add(result, a); + } + result +} + +/// Returns self * rhs mod p +pub(super) const fn mul(a: U256, b: U256) -> U256 { + let (lo, hi): (U256, U256) = a.mul_wide(&b); + montgomery_reduce(lo, hi) +} + +pub(super) const fn sub(a: U256, b: U256) -> U256 { + let a = a.as_limbs(); + let b = b.as_limbs(); + + let (result, _) = sub_inner( + [a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], Limb::ZERO], + [b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], Limb::ZERO], + ); + U256::new([ + result[0], result[1], result[2], result[3], result[4], result[5], result[6], result[7], + ]) +} + +#[inline] +pub(super) const fn to_canonical(a: U256) -> U256 { + montgomery_reduce(a, U256::ZERO) +} + +pub(super) fn from_bytes_wide(a: U512) -> U256 { + let words = a.to_limbs(); + montgomery_reduce( + U256::new([ + words[8], words[9], words[10], words[11], words[12], words[13], words[14], words[15], + ]), + U256::new([ + words[0], words[1], words[2], words[3], words[4], words[5], words[6], words[7], + ]), + ) +} + +/// Montgomery Reduction +/// +/// The general algorithm is: +/// ```text +/// A <- input (2n b-limbs) +/// for i in 0..n { +/// k <- A[i] p' mod b +/// A <- A + k p b^i +/// } +/// A <- A / b^n +/// if A >= p { +/// A <- A - p +/// } +/// ``` +/// +/// For secp256r1, with a 32-bit arithmetic, we have the following +/// simplifications: +/// +/// - `p'` is 1, so our multiplicand is simply the first limb of the intermediate A. 
+/// +/// - The first limb of p is 2^32 - 1; multiplications by this limb can be simplified +/// to a shift and subtraction: +/// ```text +/// a_i * (2^32 - 1) = a_i * 2^32 - a_i = (a_i << 32) - a_i +/// ``` +/// However, because `p' = 1`, the first limb of p is multiplied by limb i of the +/// intermediate A and then immediately added to that same limb, so we simply +/// initialize the carry to limb i of the intermediate. +/// +/// The same applies for the second and third limb. +/// +/// - The fourth limb of p is zero, so we can ignore any multiplications by it and just +/// add the carry. +/// +/// The same applies for the fifth and sixth limb. +/// +/// - The seventh limb of p is one, so we can substitute a `mac` operation with a `adc` one. +/// +/// References: +/// - Handbook of Applied Cryptography, Chapter 14 +/// Algorithm 14.32 +/// http://cacr.uwaterloo.ca/hac/about/chap14.pdf +/// +/// - Efficient and Secure Elliptic Curve Cryptography Implementation of Curve P-256 +/// Algorithm 7) Montgomery Word-by-Word Reduction +/// https://csrc.nist.gov/csrc/media/events/workshop-on-elliptic-curve-cryptography-standards/documents/papers/session6-adalier-mehmet.pdf +#[inline] +#[allow(clippy::too_many_arguments)] +pub(super) const fn montgomery_reduce(lo: U256, hi: U256) -> U256 { + let lo = lo.as_limbs(); + let hi = hi.as_limbs(); + + let a0 = lo[0]; + let a1 = lo[1]; + let a2 = lo[2]; + let a3 = lo[3]; + let a4 = lo[4]; + let a5 = lo[5]; + let a6 = lo[6]; + let a7 = lo[7]; + let a8 = hi[0]; + let a9 = hi[1]; + let a10 = hi[2]; + let a11 = hi[3]; + let a12 = hi[4]; + let a13 = hi[5]; + let a14 = hi[6]; + let a15 = hi[7]; + + let modulus = MODULUS.0.as_limbs(); + + /* + * let (a0, c) = (0, a0); + * let (a1, c) = (a1, a0); + * let (a2, c) = (a2, a0); + */ + let (a3, carry) = a3.adc(Limb::ZERO, a0); + let (a4, carry) = a4.adc(Limb::ZERO, carry); + let (a5, carry) = a5.adc(Limb::ZERO, carry); + let (a6, carry) = a6.adc(a0, carry); + // NOTE `modulus[7]` is 2^32 - 1, 
this could be optimized to `adc` and `sbb` + // but multiplication costs 1 clock-cycle on several architectures, + // thanks to parallelization + let (a7, carry) = a7.mac(a0, modulus[7], carry); + /* optimization with only adc and sbb + * let (x, _) = sbb(0, a0, 0); + * let (y, _) = sbb(a0, 0, (a0 != 0) as u32); + * + * (a7, carry) = adc(a7, x, carry); + * (carry, _) = adc(y, 0, carry); + */ + let (a8, carry2) = a8.adc(Limb::ZERO, carry); + + let (a4, carry) = a4.adc(Limb::ZERO, a1); + let (a5, carry) = a5.adc(Limb::ZERO, carry); + let (a6, carry) = a6.adc(Limb::ZERO, carry); + let (a7, carry) = a7.adc(a1, carry); + let (a8, carry) = a8.mac(a1, modulus[7], carry); + let (a9, carry2) = a9.adc(carry2, carry); + + let (a5, carry) = a5.adc(Limb::ZERO, a2); + let (a6, carry) = a6.adc(Limb::ZERO, carry); + let (a7, carry) = a7.adc(Limb::ZERO, carry); + let (a8, carry) = a8.adc(a2, carry); + let (a9, carry) = a9.mac(a2, modulus[7], carry); + let (a10, carry2) = a10.adc(carry2, carry); + + let (a6, carry) = a6.adc(Limb::ZERO, a3); + let (a7, carry) = a7.adc(Limb::ZERO, carry); + let (a8, carry) = a8.adc(Limb::ZERO, carry); + let (a9, carry) = a9.adc(a3, carry); + let (a10, carry) = a10.mac(a3, modulus[7], carry); + let (a11, carry2) = a11.adc(carry2, carry); + + let (a7, carry) = a7.adc(Limb::ZERO, a4); + let (a8, carry) = a8.adc(Limb::ZERO, carry); + let (a9, carry) = a9.adc(Limb::ZERO, carry); + let (a10, carry) = a10.adc(a4, carry); + let (a11, carry) = a11.mac(a4, modulus[7], carry); + let (a12, carry2) = a12.adc(carry2, carry); + + let (a8, carry) = a8.adc(Limb::ZERO, a5); + let (a9, carry) = a9.adc(Limb::ZERO, carry); + let (a10, carry) = a10.adc(Limb::ZERO, carry); + let (a11, carry) = a11.adc(a5, carry); + let (a12, carry) = a12.mac(a5, modulus[7], carry); + let (a13, carry2) = a13.adc(carry2, carry); + + let (a9, carry) = a9.adc(Limb::ZERO, a6); + let (a10, carry) = a10.adc(Limb::ZERO, carry); + let (a11, carry) = a11.adc(Limb::ZERO, carry); + let (a12, carry) = 
a12.adc(a6, carry); + let (a13, carry) = a13.mac(a6, modulus[7], carry); + let (a14, carry2) = a14.adc(carry2, carry); + + let (a10, carry) = a10.adc(Limb::ZERO, a7); + let (a11, carry) = a11.adc(Limb::ZERO, carry); + let (a12, carry) = a12.adc(Limb::ZERO, carry); + let (a13, carry) = a13.adc(a7, carry); + let (a14, carry) = a14.mac(a7, modulus[7], carry); + let (a15, a16) = a15.adc(carry2, carry); + + // Result may be within MODULUS of the correct value + let (result, _) = sub_inner( + [a8, a9, a10, a11, a12, a13, a14, a15, a16], + [ + modulus[0], + modulus[1], + modulus[2], + modulus[3], + modulus[4], + modulus[5], + modulus[6], + modulus[7], + Limb::ZERO, + ], + ); + + U256::new([ + result[0], result[1], result[2], result[3], result[4], result[5], result[6], result[7], + ]) +} + +#[inline] +#[allow(clippy::too_many_arguments)] +const fn sub_inner(l: [Limb; 9], r: [Limb; 9]) -> ([Limb; 8], Limb) { + let (w0, borrow) = l[0].sbb(r[0], Limb::ZERO); + let (w1, borrow) = l[1].sbb(r[1], borrow); + let (w2, borrow) = l[2].sbb(r[2], borrow); + let (w3, borrow) = l[3].sbb(r[3], borrow); + let (w4, borrow) = l[4].sbb(r[4], borrow); + let (w5, borrow) = l[5].sbb(r[5], borrow); + let (w6, borrow) = l[6].sbb(r[6], borrow); + let (w7, borrow) = l[7].sbb(r[7], borrow); + let (_, borrow) = l[8].sbb(r[8], borrow); + + // If underflow occurred on the final limb, borrow = 0xfff...fff, otherwise + // borrow = 0x000...000. Thus, we use it as a mask to conditionally add + // the modulus. 
+ + let modulus = MODULUS.0.as_limbs(); + + let (w0, carry) = w0.adc(modulus[0].bitand(borrow), Limb::ZERO); + let (w1, carry) = w1.adc(modulus[1].bitand(borrow), carry); + let (w2, carry) = w2.adc(modulus[2].bitand(borrow), carry); + let (w3, carry) = w3.adc(modulus[3].bitand(borrow), carry); + let (w4, carry) = w4.adc(modulus[4].bitand(borrow), carry); + let (w5, carry) = w5.adc(modulus[5].bitand(borrow), carry); + let (w6, carry) = w6.adc(modulus[6].bitand(borrow), carry); + let (w7, _) = w7.adc(modulus[7].bitand(borrow), carry); + + ([w0, w1, w2, w3, w4, w5, w6, w7], borrow) +} diff --git a/p256/src/arithmetic/field/field64.rs b/p256/src/arithmetic/field/field64.rs new file mode 100644 index 00000000..652d7391 --- /dev/null +++ b/p256/src/arithmetic/field/field64.rs @@ -0,0 +1,213 @@ +//! 64-bit secp256r1 field element algorithms. + +use super::MODULUS; +use elliptic_curve::bigint::{Limb, U256, U512}; + +pub(super) const fn add(a: U256, b: U256) -> U256 { + let a = a.as_limbs(); + let b = b.as_limbs(); + + // Bit 256 of p is set, so addition can result in five words. + let (w0, carry) = a[0].adc(b[0], Limb::ZERO); + let (w1, carry) = a[1].adc(b[1], carry); + let (w2, carry) = a[2].adc(b[2], carry); + let (w3, w4) = a[3].adc(b[3], carry); + // let (w0, carry) = adc(a[0], b[0], 0); + // let (w1, carry) = adc(a[1], b[1], carry); + // let (w2, carry) = adc(a[2], b[2], carry); + // let (w3, w4) = adc(a[3], b[3], carry); + + // Attempt to subtract the modulus, to ensure the result is in the field. + let modulus = MODULUS.0.as_limbs(); + + let (result, _) = sub_inner( + [w0, w1, w2, w3, w4], + [modulus[0], modulus[1], modulus[2], modulus[3], Limb::ZERO], + ); + U256::new([result[0], result[1], result[2], result[3]]) +} + +/// Multiplies by a single-limb integer. +/// Multiplies the magnitude by the same value. 
+pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { + let mut result = U256::ZERO; + for _i in 0..rhs { + result = add(result, a); // accumulate: after k iterations result == k * a mod p + } + result +} + +/// Returns self * rhs mod p +pub(super) fn mul(a: U256, b: U256) -> U256 { + let (lo, hi): (U256, U256) = a.mul_wide(&b); + montgomery_reduce(lo, hi) +} + +pub(super) const fn sub(a: U256, b: U256) -> U256 { + let a = a.as_limbs(); + let b = b.as_limbs(); + + let (result, _) = sub_inner( + [a[0], a[1], a[2], a[3], Limb::ZERO], + [b[0], b[1], b[2], b[3], Limb::ZERO], + ); + U256::new([result[0], result[1], result[2], result[3]]) +} + +#[inline] +pub(super) const fn to_canonical(a: U256) -> U256 { + montgomery_reduce(a, U256::ZERO) +} + +pub(super) fn from_bytes_wide(a: U512) -> U256 { + let words = a.to_limbs(); + montgomery_reduce( + U256::new([words[4], words[5], words[6], words[7]]), + U256::new([words[0], words[1], words[2], words[3]]), + ) +} + +/// Montgomery Reduction +/// +/// The general algorithm is: +/// ```text +/// A <- input (2n b-limbs) +/// for i in 0..n { +/// k <- A[i] p' mod b +/// A <- A + k p b^i +/// } +/// A <- A / b^n +/// if A >= p { +/// A <- A - p +/// } +/// ``` +/// +/// For secp256r1, with a 64-bit arithmetic, we have the following +/// simplifications: +/// +/// - `p'` is 1, so our multiplicand is simply the first limb of the intermediate A. +/// +/// - The first limb of p is 2^64 - 1; multiplications by this limb can be simplified +/// to a shift and subtraction: +/// ```text +/// a_i * (2^64 - 1) = a_i * 2^64 - a_i = (a_i << 64) - a_i +/// ``` +/// However, because `p' = 1`, the first limb of p is multiplied by limb i of the +/// intermediate A and then immediately added to that same limb, so we simply +/// initialize the carry to limb i of the intermediate. +/// +/// - The third limb of p is zero, so we can ignore any multiplications by it and just +/// add the carry. 
+/// +/// References: +/// - Handbook of Applied Cryptography, Chapter 14 +/// Algorithm 14.32 +/// http://cacr.uwaterloo.ca/hac/about/chap14.pdf +/// +/// - Efficient and Secure Elliptic Curve Cryptography Implementation of Curve P-256 +/// Algorithm 7) Montgomery Word-by-Word Reduction +/// https://csrc.nist.gov/csrc/media/events/workshop-on-elliptic-curve-cryptography-standards/documents/papers/session6-adalier-mehmet.pdf +#[inline] +#[allow(clippy::too_many_arguments)] +pub(super) const fn montgomery_reduce(lo: U256, hi: U256) -> U256 { + let lo = lo.as_limbs(); + let hi = hi.as_limbs(); + + let a0 = lo[0]; + let a1 = lo[1]; + let a2 = lo[2]; + let a3 = lo[3]; + let a4 = hi[0]; + let a5 = hi[1]; + let a6 = hi[2]; + let a7 = hi[3]; + + let modulus = MODULUS.0.as_limbs(); + + /* + let (a1, carry) = mac(a1, a0, modulus[1], a0); + let (a2, carry) = adc(a2, 0, carry); + let (a3, carry) = mac(a3, a0, modulus[3], carry); + let (a4, carry2) = adc(a4, 0, carry); + + let (a2, carry) = mac(a2, a1, modulus[1], a1); + let (a3, carry) = adc(a3, 0, carry); + let (a4, carry) = mac(a4, a1, modulus[3], carry); + let (a5, carry2) = adc(a5, carry2, carry); + + let (a3, carry) = mac(a3, a2, modulus[1], a2); + let (a4, carry) = adc(a4, 0, carry); + let (a5, carry) = mac(a5, a2, modulus[3], carry); + let (a6, carry2) = adc(a6, carry2, carry); + + let (a4, carry) = mac(a4, a3, modulus[1], a3); + let (a5, carry) = adc(a5, 0, carry); + let (a6, carry) = mac(a6, a3, modulus[3], carry); + let (a7, a8) = adc(a7, carry2, carry); + */ + + let (a1, carry) = a1.mac(a0, modulus[1], a0); + let (a2, carry) = a2.adc(Limb::ZERO, carry); + let (a3, carry) = a3.mac(a0, modulus[3], carry); + let (a4, carry2) = a4.adc(Limb::ZERO, carry); + + let (a2, carry) = a2.mac(a1, modulus[1], a1); + let (a3, carry) = a3.adc(Limb::ZERO, carry); + let (a4, carry) = a4.mac(a1, modulus[3], carry); + let (a5, carry2) = a5.adc(carry2, carry); + + let (a3, carry) = a3.mac(a2, modulus[1], a2); + let (a4, carry) = 
a4.adc(Limb::ZERO, carry); + let (a5, carry) = a5.mac(a2, modulus[3], carry); + let (a6, carry2) = a6.adc(carry2, carry); + + let (a4, carry) = a4.mac(a3, modulus[1], a3); + let (a5, carry) = a5.adc(Limb::ZERO, carry); + let (a6, carry) = a6.mac(a3, modulus[3], carry); + let (a7, a8) = a7.adc(carry2, carry); + + // Result may be within MODULUS of the correct value + let (result, _) = sub_inner( + [a4, a5, a6, a7, a8], + [modulus[0], modulus[1], modulus[2], modulus[3], Limb::ZERO], + ); + U256::new([result[0], result[1], result[2], result[3]]) +} + +#[inline] +#[allow(clippy::too_many_arguments)] +const fn sub_inner(l: [Limb; 5], r: [Limb; 5]) -> ([Limb; 4], Limb) { + /* + let (w0, borrow) = sbb(l[0], r[0], 0); + let (w1, borrow) = sbb(l[1], r[1], borrow); + let (w2, borrow) = sbb(l[2], r[2], borrow); + let (w3, borrow) = sbb(l[3], r[3], borrow); + let (_, borrow) = sbb(l[4], r[4], borrow); + */ + + let (w0, borrow) = l[0].sbb(r[0], Limb::ZERO); + let (w1, borrow) = l[1].sbb(r[1], borrow); + let (w2, borrow) = l[2].sbb(r[2], borrow); + let (w3, borrow) = l[3].sbb(r[3], borrow); + let (_, borrow) = l[4].sbb(r[4], borrow); + + // If underflow occurred on the final limb, borrow = 0xfff...fff, otherwise + // borrow = 0x000...000. Thus, we use it as a mask to conditionally add the + // modulus. 
+ + let modulus = MODULUS.0.as_limbs(); + + /* + let (w0, carry) = adc(w0, modulus[0] & borrow, 0); + let (w1, carry) = adc(w1, modulus[1] & borrow, carry); + let (w2, carry) = adc(w2, modulus[2] & borrow, carry); + let (w3, _) = adc(w3, modulus[3] & borrow, carry); + */ + + let (w0, carry) = w0.adc(modulus[0].bitand(borrow), Limb::ZERO); + let (w1, carry) = w1.adc(modulus[1].bitand(borrow), carry); + let (w2, carry) = w2.adc(modulus[2].bitand(borrow), carry); + let (w3, _) = w3.adc(modulus[3].bitand(borrow), carry); + + ([w0, w1, w2, w3], borrow) +} diff --git a/p256/src/arithmetic/field/field_risc0.rs b/p256/src/arithmetic/field/field_risc0.rs new file mode 100644 index 00000000..4fcf7d40 --- /dev/null +++ b/p256/src/arithmetic/field/field_risc0.rs @@ -0,0 +1,391 @@ +//! secp256r1 field element algorithms for the RISC Zero zkVM (eight 32-bit limbs). + +use super::{MODULUS, MODULUS_HEX}; +use elliptic_curve::bigint::{risc0, Limb, U128, U256, U512}; + +const MODULUS_256: U256 = U256::from_be_hex(MODULUS_HEX); + +pub(super) const fn add(a: U256, b: U256) -> U256 { + let a = a.as_limbs(); + let b = b.as_limbs(); + + // Bit 256 of p is set, so addition can result in nine words. + // let (w0, carry) = adc(a[0], b[0], 0); + let (w0, carry) = a[0].adc(b[0], Limb::ZERO); + let (w1, carry) = a[1].adc(b[1], carry); + let (w2, carry) = a[2].adc(b[2], carry); + let (w3, carry) = a[3].adc(b[3], carry); + let (w4, carry) = a[4].adc(b[4], carry); + let (w5, carry) = a[5].adc(b[5], carry); + let (w6, carry) = a[6].adc(b[6], carry); + let (w7, w8) = a[7].adc(b[7], carry); + // Attempt to subtract the modulus, to ensure the result is in the field.
+ let modulus = MODULUS.0.as_limbs(); + + let (result, _) = sub_inner( + [w0, w1, w2, w3, w4, w5, w6, w7, w8], + [ + modulus[0], + modulus[1], + modulus[2], + modulus[3], + modulus[4], + modulus[5], + modulus[6], + modulus[7], + Limb::ZERO, + ], + ); + U256::new([ + result[0], result[1], result[2], result[3], result[4], result[5], result[6], result[7], + ]) +} + +/// Multiplies by a single-limb integer. +/// Calls `risc0::modmul_u256_denormalized` with `MODULUS_256` as the modulus; the result is returned without a range check. +pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { + risc0::modmul_u256_denormalized( + &a, + &U256::from_words([rhs, 0, 0, 0, 0, 0, 0, 0]), + &MODULUS_256, + ) +} + +/// Wide multiplication of two 256-bit Uint values +/// Given 2 `U256` values a and b, we can decompose them into: +/// ```text +/// a = a1 * 2^(128) + a0 +/// b = b1 * 2^(128) + b0 +/// ``` +/// Then the product P = a * b can be expressed as: +/// ```text +/// P = a1 * b1 * 2^(256) + (a1 * b0 + a0 * b1) * 2^(128) + a0 * b0 +/// ``` +/// Hence we need to calculate the partial products (a1 * b1), (a1 * b0), +/// (a0 * b1) and (a0 * b0) using RISC Zero's accelerator and +/// combine the results into a wide value for montgomery reduction +pub(super) fn mul_wide_256(a: U256, b: U256) -> (U256, U256) { + // Split each U256 into two U128 + let a0 = U128::from_words([ + a.as_words()[0], + a.as_words()[1], + a.as_words()[2], + a.as_words()[3], + ]); + let a1 = U128::from_words([ + a.as_words()[4], + a.as_words()[5], + a.as_words()[6], + a.as_words()[7], + ]); + let b0 = U128::from_words([ + b.as_words()[0], + b.as_words()[1], + b.as_words()[2], + b.as_words()[3], + ]); + let b1 = U128::from_words([ + b.as_words()[4], + b.as_words()[5], + b.as_words()[6], + b.as_words()[7], + ]); + + // Perform the four multiplications using RISC Zero Accelerator + let p0 = risc0::mul_wide_u128(&a0, &b0); + let p1 = risc0::mul_wide_u128(&a0, &b1); + let p2 = risc0::mul_wide_u128(&a1, &b0); + let p3 = risc0::mul_wide_u128(&a1, &b1); + + // Initialize the U512 result + let mut result
= [0u32; 16]; + let mut carry = 0; + let mut carry12 = 0; + + // Copy p0 to result[0..8] + for i in 0..8 { + result[i] = p0.as_words()[i]; + } + + // Add p1 shifted left by 128 bits to result[4..12] + for i in 0..8 { + let (sum, c) = result[i + 4].overflowing_add(p1.as_words()[i]); + let (sum_with_carry, c2) = sum.overflowing_add(carry); + result[i + 4] = sum_with_carry; + carry = (c as u32) + (c2 as u32); + if i == 7 { + // We need to account for the carry for result[12] + carry12 = carry + carry12; + } + } + // Reset carry for next addition + carry = 0; + + // Add p2 shifted left by 128 bits to result[4..12] + for i in 0..8 { + let (sum, c) = result[i + 4].overflowing_add(p2.as_words()[i]); + let (sum_with_carry, c2) = sum.overflowing_add(carry); + result[i + 4] = sum_with_carry; + carry = (c as u32) + (c2 as u32); + if i == 7 { + // We need to account for the carry for result[12] + carry12 = carry + carry12; + } + } + // Reset carry for next addition + carry = 0; + + // Add p3 shifted left by 256 bits to result[8..16] + for i in 0..8 { + if i == 4 { + // Apply the carry that we accounted for + carry = carry12 + carry; + } + let (sum, c) = result[i + 8].overflowing_add(p3.as_words()[i]); + // `sum` wraps on purpose; the overflow is recorded in `c` and carried forward. + let (sum_with_carry, c2) = sum.overflowing_add(carry); + // Same for the incoming carry: its overflow is recorded in `c2`. + result[i + 8] = sum_with_carry; + carry = (c as u32) + (c2 as u32); + } + + let low = U256::from_words([ + result[0], result[1], result[2], result[3], result[4], result[5], result[6], result[7], + ]); + let high = U256::from_words([ + result[8], result[9], result[10], result[11], result[12], result[13], result[14], + result[15], + ]); + + (low, high) +} + +/// Returns self * rhs mod p +pub(super) fn mul(a: U256, b: U256) -> U256 { + let (low, high) = mul_wide_256(a, b); + montgomery_reduce(low, high) +} + +pub(super) const fn sub(a: U256, b: U256) -> U256 { + let a = a.as_limbs(); + let b = b.as_limbs(); + + let (result, _) =
sub_inner( + [a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], Limb::ZERO], + [b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], Limb::ZERO], + ); + U256::new([ + result[0], result[1], result[2], result[3], result[4], result[5], result[6], result[7], + ]) +} + +#[inline] +pub(super) const fn to_canonical(a: U256) -> U256 { + montgomery_reduce(a, U256::ZERO) +} + +pub(super) fn from_bytes_wide(a: U512) -> U256 { + let words = a.to_limbs(); + montgomery_reduce( + U256::new([ + words[8], words[9], words[10], words[11], words[12], words[13], words[14], words[15], + ]), + U256::new([ + words[0], words[1], words[2], words[3], words[4], words[5], words[6], words[7], + ]), + ) +} + +/// Montgomery Reduction +/// +/// The general algorithm is: +/// ```text +/// A <- input (2n b-limbs) +/// for i in 0..n { +/// k <- A[i] p' mod b +/// A <- A + k p b^i +/// } +/// A <- A / b^n +/// if A >= p { +/// A <- A - p +/// } +/// ``` +/// +/// For secp256r1, with a 32-bit arithmetic, we have the following +/// simplifications: +/// +/// - `p'` is 1, so our multiplicand is simply the first limb of the intermediate A. +/// +/// - The first limb of p is 2^32 - 1; multiplications by this limb can be simplified +/// to a shift and subtraction: +/// ```text +/// a_i * (2^32 - 1) = a_i * 2^32 - a_i = (a_i << 32) - a_i +/// ``` +/// However, because `p' = 1`, the first limb of p is multiplied by limb i of the +/// intermediate A and then immediately added to that same limb, so we simply +/// initialize the carry to limb i of the intermediate. +/// +/// The same applies for the second and third limb. +/// +/// - The fourth limb of p is zero, so we can ignore any multiplications by it and just +/// add the carry. +/// +/// The same applies for the fifth and sixth limb. +/// +/// - The seventh limb of p is one, so we can substitute a `mac` operation with a `adc` one. 
+/// +/// References: +/// - Handbook of Applied Cryptography, Chapter 14 +/// Algorithm 14.32 +/// http://cacr.uwaterloo.ca/hac/about/chap14.pdf +/// +/// - Efficient and Secure Elliptic Curve Cryptography Implementation of Curve P-256 +/// Algorithm 7) Montgomery Word-by-Word Reduction +/// https://csrc.nist.gov/csrc/media/events/workshop-on-elliptic-curve-cryptography-standards/documents/papers/session6-adalier-mehmet.pdf +#[inline] +#[allow(clippy::too_many_arguments)] +pub(super) const fn montgomery_reduce(lo: U256, hi: U256) -> U256 { + let lo = lo.as_limbs(); + let hi = hi.as_limbs(); + + let a0 = lo[0]; + let a1 = lo[1]; + let a2 = lo[2]; + let a3 = lo[3]; + let a4 = lo[4]; + let a5 = lo[5]; + let a6 = lo[6]; + let a7 = lo[7]; + let a8 = hi[0]; + let a9 = hi[1]; + let a10 = hi[2]; + let a11 = hi[3]; + let a12 = hi[4]; + let a13 = hi[5]; + let a14 = hi[6]; + let a15 = hi[7]; + + let modulus = MODULUS.0.as_limbs(); + + /* + * let (a0, c) = (0, a0); + * let (a1, c) = (a1, a0); + * let (a2, c) = (a2, a0); + */ + let (a3, carry) = a3.adc(Limb::ZERO, a0); + let (a4, carry) = a4.adc(Limb::ZERO, carry); + let (a5, carry) = a5.adc(Limb::ZERO, carry); + let (a6, carry) = a6.adc(a0, carry); + // NOTE `modulus[7]` is 2^32 - 1, this could be optimized to `adc` and `sbb` + // but multiplication costs 1 clock-cycle on several architectures, + // thanks to parallelization + let (a7, carry) = a7.mac(a0, modulus[7], carry); + /* optimization with only adc and sbb + * let (x, _) = sbb(0, a0, 0); + * let (y, _) = sbb(a0, 0, (a0 != 0) as u32); + * + * (a7, carry) = adc(a7, x, carry); + * (carry, _) = adc(y, 0, carry); + */ + let (a8, carry2) = a8.adc(Limb::ZERO, carry); + + let (a4, carry) = a4.adc(Limb::ZERO, a1); + let (a5, carry) = a5.adc(Limb::ZERO, carry); + let (a6, carry) = a6.adc(Limb::ZERO, carry); + let (a7, carry) = a7.adc(a1, carry); + let (a8, carry) = a8.mac(a1, modulus[7], carry); + let (a9, carry2) = a9.adc(carry2, carry); + + let (a5, carry) = 
a5.adc(Limb::ZERO, a2); + let (a6, carry) = a6.adc(Limb::ZERO, carry); + let (a7, carry) = a7.adc(Limb::ZERO, carry); + let (a8, carry) = a8.adc(a2, carry); + let (a9, carry) = a9.mac(a2, modulus[7], carry); + let (a10, carry2) = a10.adc(carry2, carry); + + let (a6, carry) = a6.adc(Limb::ZERO, a3); + let (a7, carry) = a7.adc(Limb::ZERO, carry); + let (a8, carry) = a8.adc(Limb::ZERO, carry); + let (a9, carry) = a9.adc(a3, carry); + let (a10, carry) = a10.mac(a3, modulus[7], carry); + let (a11, carry2) = a11.adc(carry2, carry); + + let (a7, carry) = a7.adc(Limb::ZERO, a4); + let (a8, carry) = a8.adc(Limb::ZERO, carry); + let (a9, carry) = a9.adc(Limb::ZERO, carry); + let (a10, carry) = a10.adc(a4, carry); + let (a11, carry) = a11.mac(a4, modulus[7], carry); + let (a12, carry2) = a12.adc(carry2, carry); + + let (a8, carry) = a8.adc(Limb::ZERO, a5); + let (a9, carry) = a9.adc(Limb::ZERO, carry); + let (a10, carry) = a10.adc(Limb::ZERO, carry); + let (a11, carry) = a11.adc(a5, carry); + let (a12, carry) = a12.mac(a5, modulus[7], carry); + let (a13, carry2) = a13.adc(carry2, carry); + + let (a9, carry) = a9.adc(Limb::ZERO, a6); + let (a10, carry) = a10.adc(Limb::ZERO, carry); + let (a11, carry) = a11.adc(Limb::ZERO, carry); + let (a12, carry) = a12.adc(a6, carry); + let (a13, carry) = a13.mac(a6, modulus[7], carry); + let (a14, carry2) = a14.adc(carry2, carry); + + let (a10, carry) = a10.adc(Limb::ZERO, a7); + let (a11, carry) = a11.adc(Limb::ZERO, carry); + let (a12, carry) = a12.adc(Limb::ZERO, carry); + let (a13, carry) = a13.adc(a7, carry); + let (a14, carry) = a14.mac(a7, modulus[7], carry); + let (a15, a16) = a15.adc(carry2, carry); + + // Result may be within MODULUS of the correct value + let (result, _) = sub_inner( + [a8, a9, a10, a11, a12, a13, a14, a15, a16], + [ + modulus[0], + modulus[1], + modulus[2], + modulus[3], + modulus[4], + modulus[5], + modulus[6], + modulus[7], + Limb::ZERO, + ], + ); + + U256::new([ + result[0], result[1], result[2], result[3], 
result[4], result[5], result[6], result[7], + ]) +} + +#[inline] +#[allow(clippy::too_many_arguments)] +const fn sub_inner(l: [Limb; 9], r: [Limb; 9]) -> ([Limb; 8], Limb) { + let (w0, borrow) = l[0].sbb(r[0], Limb::ZERO); + let (w1, borrow) = l[1].sbb(r[1], borrow); + let (w2, borrow) = l[2].sbb(r[2], borrow); + let (w3, borrow) = l[3].sbb(r[3], borrow); + let (w4, borrow) = l[4].sbb(r[4], borrow); + let (w5, borrow) = l[5].sbb(r[5], borrow); + let (w6, borrow) = l[6].sbb(r[6], borrow); + let (w7, borrow) = l[7].sbb(r[7], borrow); + let (_, borrow) = l[8].sbb(r[8], borrow); + + // If underflow occurred on the final limb, borrow = 0xfff...fff, otherwise + // borrow = 0x000...000. Thus, we use it as a mask to conditionally add + // the modulus. + + let modulus = MODULUS.0.as_limbs(); + + let (w0, carry) = w0.adc(modulus[0].bitand(borrow), Limb::ZERO); + let (w1, carry) = w1.adc(modulus[1].bitand(borrow), carry); + let (w2, carry) = w2.adc(modulus[2].bitand(borrow), carry); + let (w3, carry) = w3.adc(modulus[3].bitand(borrow), carry); + let (w4, carry) = w4.adc(modulus[4].bitand(borrow), carry); + let (w5, carry) = w5.adc(modulus[5].bitand(borrow), carry); + let (w6, carry) = w6.adc(modulus[6].bitand(borrow), carry); + let (w7, _) = w7.adc(modulus[7].bitand(borrow), carry); + + ([w0, w1, w2, w3, w4, w5, w6, w7], borrow) +} diff --git a/p256/src/arithmetic/scalar.rs b/p256/src/arithmetic/scalar.rs index 61a331ec..50357483 100644 --- a/p256/src/arithmetic/scalar.rs +++ b/p256/src/arithmetic/scalar.rs @@ -4,7 +4,9 @@ #[cfg_attr(target_pointer_width = "64", path = "scalar/scalar64.rs")] mod scalar_impl; -use self::scalar_impl::barrett_reduce; +#[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] +use elliptic_curve::bigint::risc0; + use crate::{FieldBytes, NistP256, SecretKey, ORDER_HEX}; use core::{ fmt::{self, Debug}, @@ -38,17 +40,6 @@ pub(crate) const MODULUS: U256 = NistP256::ORDER; /// `MODULUS / 2` const FRAC_MODULUS_2: Scalar = Scalar(MODULUS.shr_vartime(1)); 
-/// MU = floor(2^512 / n) -/// = 115792089264276142090721624801893421302707618245269942344307673200490803338238 -/// = 0x100000000fffffffffffffffeffffffff43190552df1a6c21012ffd85eedf9bfe -pub const MU: [u64; 5] = [ - 0x012f_fd85_eedf_9bfe, - 0x4319_0552_df1a_6c21, - 0xffff_fffe_ffff_ffff, - 0x0000_0000_ffff_ffff, - 0x0000_0000_0000_0001, -]; - /// Scalars are elements in the finite field modulo n. /// /// # Trait impls @@ -101,8 +92,16 @@ impl Scalar { } /// Returns 2*self. - pub const fn double(&self) -> Self { - self.add(self) + pub fn double(&self) -> Self { + cfg_if::cfg_if! { + if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] { + let result = Self(risc0::modmul_u256_denormalized(&self.0, &U256::from_u8(2), &NistP256::ORDER)); + assert!(bool::from(result.0.ct_lt(&NistP256::ORDER))); + result + } else { + self.add(self) + } + } } /// Returns self - rhs mod n. @@ -111,13 +110,21 @@ impl Scalar { } /// Returns self * rhs mod n - pub const fn multiply(&self, rhs: &Self) -> Self { - let (lo, hi) = self.0.mul_wide(&rhs.0); - Self(barrett_reduce(lo, hi)) + pub fn multiply(&self, rhs: &Self) -> Self { + cfg_if::cfg_if! { + if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] { + let result = Self(risc0::modmul_u256_denormalized(&self.0, &rhs.0, &NistP256::ORDER)); + assert!(bool::from(result.0.ct_lt(&NistP256::ORDER))); + result + } else { + let (lo, hi) = self.0.mul_wide(&rhs.0); + Self(scalar_impl::barrett_reduce(lo, hi)) + } + } } /// Returns self * self mod p - pub const fn square(&self) -> Self { + pub fn square(&self) -> Self { // Schoolbook multiplication. self.multiply(self) } @@ -137,7 +144,7 @@ impl Scalar { /// Returns the multiplicative inverse of self. /// /// Does not check that self is non-zero. - const fn invert_unchecked(&self) -> Self { + fn invert_unchecked(&self) -> Self { // We need to find b such that b * a ≡ 1 mod p.
As we are in a prime // field, we can apply Fermat's Little Theorem: // @@ -158,7 +165,7 @@ impl Scalar { /// Exponentiates `self` by `exp`, where `exp` is a little-endian order integer /// exponent. - pub const fn pow_vartime(&self, exp: &[u64]) -> Self { + pub fn pow_vartime(&self, exp: &[u64]) -> Self { let mut res = Self::ONE; let mut i = exp.len(); @@ -287,13 +294,17 @@ impl PrimeField for Scalar { const MODULUS: &'static str = ORDER_HEX; const NUM_BITS: u32 = 256; const CAPACITY: u32 = 255; - const TWO_INV: Self = Self(U256::from_u8(2)).invert_unchecked(); + const TWO_INV: Self = Self(U256::from_be_hex( + "7FFFFFFF800000007FFFFFFFFFFFFFFFDE737D56D38BCF4279DCE5617E3192A9", + )); const MULTIPLICATIVE_GENERATOR: Self = Self(U256::from_u8(7)); const S: u32 = 4; const ROOT_OF_UNITY: Self = Self(U256::from_be_hex( "ffc97f062a770992ba807ace842a3dfc1546cad004378daf0592d7fbb41e6602", )); - const ROOT_OF_UNITY_INV: Self = Self::ROOT_OF_UNITY.invert_unchecked(); + const ROOT_OF_UNITY_INV: Self = Self(U256::from_be_hex( + "A0A66A5562D46F2AC645FA0458131CAEE3AC117C794C4137379C7F0657C73764", + )); const DELTA: Self = Self(U256::from_u64(33232930569601)); /// Attempts to parse the given byte array as an SEC1-encoded scalar. @@ -363,6 +374,11 @@ impl Invert for Scalar { /// sidechannels. #[allow(non_snake_case)] fn invert_vartime(&self) -> CtOption { + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + // Constant time algorithm is faster in the RISC Zero zkVM. 
+ return self.invert(); + } + let mut u = *self; let mut v = Self(MODULUS); let mut A = Self::ONE; @@ -687,7 +703,7 @@ impl ReduceNonZero for Scalar { impl Sum for Scalar { fn sum>(iter: I) -> Self { - iter.reduce(core::ops::Add::add).unwrap_or(Self::ZERO) + iter.reduce(Add::add).unwrap_or(Self::ZERO) } } @@ -699,7 +715,7 @@ impl<'a> Sum<&'a Scalar> for Scalar { impl Product for Scalar { fn product>(iter: I) -> Self { - iter.reduce(core::ops::Mul::mul).unwrap_or(Self::ONE) + iter.reduce(Mul::mul).unwrap_or(Self::ONE) } } @@ -751,7 +767,10 @@ impl<'de> Deserialize<'de> for Scalar { mod tests { use super::Scalar; use crate::{FieldBytes, SecretKey}; - use elliptic_curve::group::ff::{Field, PrimeField}; + use elliptic_curve::{ + bigint::U256, + group::ff::{Field, PrimeField}, + }; use primeorder::{ impl_field_identity_tests, impl_field_invert_tests, impl_field_sqrt_tests, impl_primefield_tests, @@ -780,6 +799,23 @@ mod tests { assert_eq!(bytes, scalar.to_bytes()); } + #[test] + fn root_of_unity_test() { + let root_of_unity_inv = Scalar::ROOT_OF_UNITY.invert_unchecked(); + assert_eq!(root_of_unity_inv, Scalar::ROOT_OF_UNITY_INV); + assert_eq!( + (Scalar::ROOT_OF_UNITY * Scalar::ROOT_OF_UNITY_INV), + Scalar::ONE + ) + } + + #[test] + fn two_inv_test() { + let number = Scalar(U256::from_u8(2)).invert_unchecked(); + assert_eq!(number, Scalar::TWO_INV); + assert_eq!((Scalar::from(2u64) * Scalar::TWO_INV), Scalar::ONE); + } + /// Basic tests that multiplication works. #[test] fn multiply() { diff --git a/p256/src/arithmetic/scalar/scalar32.rs b/p256/src/arithmetic/scalar/scalar32.rs index dfa5742e..01ca75f0 100644 --- a/p256/src/arithmetic/scalar/scalar32.rs +++ b/p256/src/arithmetic/scalar/scalar32.rs @@ -1,12 +1,21 @@ //! 32-bit secp256r1 scalar field algorithms. 
+use super::MODULUS; +use elliptic_curve::bigint::{Limb, U256}; -// TODO(tarcieri): adapt 64-bit arithmetic to proper 32-bit arithmetic - -use super::{MODULUS, MU}; -use crate::{ - arithmetic::util::{adc, mac, sbb}, - U256, -}; +/// MU = floor(2^512 / n) +/// = 115792089264276142090721624801893421302707618245269942344307673200490803338238 +/// = 0x100000000fffffffffffffffeffffffff43190552df1a6c21012ffd85eedf9bfe +const MU: [Limb; 9] = [ + Limb::from_u32(0xeedf_9bfe), + Limb::from_u32(0x012f_fd85), + Limb::from_u32(0xdf1a_6c21), + Limb::from_u32(0x4319_0552), + Limb::from_u32(0xffff_ffff), + Limb::from_u32(0xffff_fffe), + Limb::from_u32(0xffff_ffff), + Limb::from_u32(0x0000_0000), + Limb::from_u32(0x0000_0001), +]; /// Barrett Reduction /// @@ -39,150 +48,278 @@ use crate::{ #[inline] #[allow(clippy::too_many_arguments)] pub(super) const fn barrett_reduce(lo: U256, hi: U256) -> U256 { - let lo = u256_to_u64x4(lo); - let hi = u256_to_u64x4(hi); + let lo = lo.as_limbs(); + let hi = hi.as_limbs(); + let a0 = lo[0]; let a1 = lo[1]; let a2 = lo[2]; let a3 = lo[3]; - let a4 = hi[0]; - let a5 = hi[1]; - let a6 = hi[2]; - let a7 = hi[3]; - let q1: [u64; 5] = [a3, a4, a5, a6, a7]; - let q3 = q1_times_mu_shift_five(&q1); + let a4 = lo[4]; + let a5 = lo[5]; + let a6 = lo[6]; + let a7 = lo[7]; + let a8 = hi[0]; + let a9 = hi[1]; + let a10 = hi[2]; + let a11 = hi[3]; + let a12 = hi[4]; + let a13 = hi[5]; + let a14 = hi[6]; + let a15 = hi[7]; + + let q1: [Limb; 9] = [a7, a8, a9, a10, a11, a12, a13, a14, a15]; + let q3: [Limb; 9] = q1_times_mu_shift_nine(&q1); - let r1: [u64; 5] = [a0, a1, a2, a3, a4]; - let r2: [u64; 5] = q3_times_n_keep_five(&q3); - let r: [u64; 5] = sub_inner_five(r1, r2); + let r1: [Limb; 9] = [a0, a1, a2, a3, a4, a5, a6, a7, a8]; + let r2: [Limb; 9] = q3_times_n_keep_nine(&q3); + let r: [Limb; 9] = sub_inner_nine(r1, r2); // Result is in range (0, 3*n - 1), // and 90% of the time, no subtraction will be needed. 
- let r = subtract_n_if_necessary(r[0], r[1], r[2], r[3], r[4]); - let r = subtract_n_if_necessary(r[0], r[1], r[2], r[3], r[4]); - - U256::from_words([ - (r[0] & 0xFFFFFFFF) as u32, - (r[0] >> 32) as u32, - (r[1] & 0xFFFFFFFF) as u32, - (r[1] >> 32) as u32, - (r[2] & 0xFFFFFFFF) as u32, - (r[2] >> 32) as u32, - (r[3] & 0xFFFFFFFF) as u32, - (r[3] >> 32) as u32, - ]) + let r = subtract_n_if_necessary(r); + let r = subtract_n_if_necessary(r); + + U256::new([r[0], r[1], r[2], r[3], r[4], r[5], r[6], r[7]]) } -const fn q1_times_mu_shift_five(q1: &[u64; 5]) -> [u64; 5] { - // Schoolbook multiplication. - - let (_w0, carry) = mac(0, q1[0], MU[0], 0); - let (w1, carry) = mac(0, q1[0], MU[1], carry); - let (w2, carry) = mac(0, q1[0], MU[2], carry); - let (w3, carry) = mac(0, q1[0], MU[3], carry); - let (w4, w5) = mac(0, q1[0], MU[4], carry); - - let (_w1, carry) = mac(w1, q1[1], MU[0], 0); - let (w2, carry) = mac(w2, q1[1], MU[1], carry); - let (w3, carry) = mac(w3, q1[1], MU[2], carry); - let (w4, carry) = mac(w4, q1[1], MU[3], carry); - let (w5, w6) = mac(w5, q1[1], MU[4], carry); - - let (_w2, carry) = mac(w2, q1[2], MU[0], 0); - let (w3, carry) = mac(w3, q1[2], MU[1], carry); - let (w4, carry) = mac(w4, q1[2], MU[2], carry); - let (w5, carry) = mac(w5, q1[2], MU[3], carry); - let (w6, w7) = mac(w6, q1[2], MU[4], carry); - - let (_w3, carry) = mac(w3, q1[3], MU[0], 0); - let (w4, carry) = mac(w4, q1[3], MU[1], carry); - let (w5, carry) = mac(w5, q1[3], MU[2], carry); - let (w6, carry) = mac(w6, q1[3], MU[3], carry); - let (w7, w8) = mac(w7, q1[3], MU[4], carry); - - let (_w4, carry) = mac(w4, q1[4], MU[0], 0); - let (w5, carry) = mac(w5, q1[4], MU[1], carry); - let (w6, carry) = mac(w6, q1[4], MU[2], carry); - let (w7, carry) = mac(w7, q1[4], MU[3], carry); - let (w8, w9) = mac(w8, q1[4], MU[4], carry); - - // let q2 = [_w0, _w1, _w2, _w3, _w4, w5, w6, w7, w8, w9]; - [w5, w6, w7, w8, w9] +const fn q1_times_mu_shift_nine(q1: &[Limb; 9]) -> [Limb; 9] { + // Schoolbook 
multiplication + + let (_w0, carry) = Limb::ZERO.mac(q1[0], MU[0], Limb::ZERO); + let (w1, carry) = Limb::ZERO.mac(q1[0], MU[1], carry); + let (w2, carry) = Limb::ZERO.mac(q1[0], MU[2], carry); + let (w3, carry) = Limb::ZERO.mac(q1[0], MU[3], carry); + let (w4, carry) = Limb::ZERO.mac(q1[0], MU[4], carry); + let (w5, carry) = Limb::ZERO.mac(q1[0], MU[5], carry); + let (w6, carry) = Limb::ZERO.mac(q1[0], MU[6], carry); + // NOTE MU[7] == 0 + // let (w7, carry) = Limb::ZERO.mac(q1[0], MU[7], carry); + let (w7, _carry) = (carry, Limb::ZERO); + // NOTE MU[8] == 1 + // let (w8, w9) = Limb::ZERO.mac(q1[0], MU[8], carry); + let (w8, w9) = (q1[0], Limb::ZERO); + + let (_w1, carry) = w1.mac(q1[1], MU[0], Limb::ZERO); + let (w2, carry) = w2.mac(q1[1], MU[1], carry); + let (w3, carry) = w3.mac(q1[1], MU[2], carry); + let (w4, carry) = w4.mac(q1[1], MU[3], carry); + let (w5, carry) = w5.mac(q1[1], MU[4], carry); + let (w6, carry) = w6.mac(q1[1], MU[5], carry); + let (w7, carry) = w7.mac(q1[1], MU[6], carry); + // NOTE MU[7] == 0 + // let (w8, carry) = w8.mac(q1[1], MU[7], carry); + let (w8, carry) = w8.adc(Limb::ZERO, carry); + // NOTE MU[8] == 1 + // let (w9, w10) = w9.mac(q1[1], MU[8], carry); + let (w9, w10) = w9.adc(q1[1], carry); + + let (_w2, carry) = w2.mac(q1[2], MU[0], Limb::ZERO); + let (w3, carry) = w3.mac(q1[2], MU[1], carry); + let (w4, carry) = w4.mac(q1[2], MU[2], carry); + let (w5, carry) = w5.mac(q1[2], MU[3], carry); + let (w6, carry) = w6.mac(q1[2], MU[4], carry); + let (w7, carry) = w7.mac(q1[2], MU[5], carry); + let (w8, carry) = w8.mac(q1[2], MU[6], carry); + // let (w9, carry) = w9.mac(q1[2], MU[7], carry); + let (w9, carry) = w9.adc(Limb::ZERO, carry); + // let (w10, w11) = w10.mac(q1[2], MU[8], carry); + let (w10, w11) = w10.adc(q1[2], carry); + + let (_w3, carry) = w3.mac(q1[3], MU[0], Limb::ZERO); + let (w4, carry) = w4.mac(q1[3], MU[1], carry); + let (w5, carry) = w5.mac(q1[3], MU[2], carry); + let (w6, carry) = w6.mac(q1[3], MU[3], carry); + let 
(w7, carry) = w7.mac(q1[3], MU[4], carry); + let (w8, carry) = w8.mac(q1[3], MU[5], carry); + let (w9, carry) = w9.mac(q1[3], MU[6], carry); + // let (w10, carry) = w10.mac(q1[3], MU[7], carry); + let (w10, carry) = w10.adc(Limb::ZERO, carry); + // let (w11, w12) = w11.mac(q1[3], MU[8], carry); + let (w11, w12) = w11.adc(q1[3], carry); + + let (_w4, carry) = w4.mac(q1[4], MU[0], Limb::ZERO); + let (w5, carry) = w5.mac(q1[4], MU[1], carry); + let (w6, carry) = w6.mac(q1[4], MU[2], carry); + let (w7, carry) = w7.mac(q1[4], MU[3], carry); + let (w8, carry) = w8.mac(q1[4], MU[4], carry); + let (w9, carry) = w9.mac(q1[4], MU[5], carry); + let (w10, carry) = w10.mac(q1[4], MU[6], carry); + // let (w11, carry) = w11.mac(q1[4], MU[7], carry); + let (w11, carry) = w11.adc(Limb::ZERO, carry); + // let (w12, w13) = w12.mac(q1[4], MU[8], carry); + let (w12, w13) = w12.adc(q1[4], carry); + + let (_w5, carry) = w5.mac(q1[5], MU[0], Limb::ZERO); + let (w6, carry) = w6.mac(q1[5], MU[1], carry); + let (w7, carry) = w7.mac(q1[5], MU[2], carry); + let (w8, carry) = w8.mac(q1[5], MU[3], carry); + let (w9, carry) = w9.mac(q1[5], MU[4], carry); + let (w10, carry) = w10.mac(q1[5], MU[5], carry); + let (w11, carry) = w11.mac(q1[5], MU[6], carry); + // let (w12, carry) = w12.mac(q1[5], MU[7], carry); + let (w12, carry) = w12.adc(Limb::ZERO, carry); + // let (w13, w14) = w13.mac(q1[5], MU[8], carry); + let (w13, w14) = w13.adc(q1[5], carry); + + let (_w6, carry) = w6.mac(q1[6], MU[0], Limb::ZERO); + let (w7, carry) = w7.mac(q1[6], MU[1], carry); + let (w8, carry) = w8.mac(q1[6], MU[2], carry); + let (w9, carry) = w9.mac(q1[6], MU[3], carry); + let (w10, carry) = w10.mac(q1[6], MU[4], carry); + let (w11, carry) = w11.mac(q1[6], MU[5], carry); + let (w12, carry) = w12.mac(q1[6], MU[6], carry); + // let (w13, carry) = w13.mac(q1[6], MU[7], carry); + let (w13, carry) = w13.adc(Limb::ZERO, carry); + // let (w14, w15) = w14.mac(q1[6], MU[8], carry); + let (w14, w15) = w14.adc(q1[6], carry); + + 
let (_w7, carry) = w7.mac(q1[7], MU[0], Limb::ZERO); + let (w8, carry) = w8.mac(q1[7], MU[1], carry); + let (w9, carry) = w9.mac(q1[7], MU[2], carry); + let (w10, carry) = w10.mac(q1[7], MU[3], carry); + let (w11, carry) = w11.mac(q1[7], MU[4], carry); + let (w12, carry) = w12.mac(q1[7], MU[5], carry); + let (w13, carry) = w13.mac(q1[7], MU[6], carry); + // let (w14, carry) = w14.mac(q1[7], MU[7], carry); + let (w14, carry) = w14.adc(Limb::ZERO, carry); + // let (w15, w16) = w15.mac(q1[7], MU[8], carry); + let (w15, w16) = w15.adc(q1[7], carry); + + let (_w8, carry) = w8.mac(q1[8], MU[0], Limb::ZERO); + let (w9, carry) = w9.mac(q1[8], MU[1], carry); + let (w10, carry) = w10.mac(q1[8], MU[2], carry); + let (w11, carry) = w11.mac(q1[8], MU[3], carry); + let (w12, carry) = w12.mac(q1[8], MU[4], carry); + let (w13, carry) = w13.mac(q1[8], MU[5], carry); + let (w14, carry) = w14.mac(q1[8], MU[6], carry); + // let (w15, carry) = w15.mac(w15, q1[8], MU[7], carry); + let (w15, carry) = w15.adc(Limb::ZERO, carry); + // let (w16, w17) = w16.mac(w16, q1[8], MU[8], carry); + let (w16, w17) = w16.adc(q1[8], carry); + + // let q2 = [_w0, _w1, _w2, _w3, _w4, _w5, _w6, _w7, _w8, w9, w10, w11, w12, w13, w14, w15, w16, w17]; + [w9, w10, w11, w12, w13, w14, w15, w16, w17] } -const fn q3_times_n_keep_five(q3: &[u64; 5]) -> [u64; 5] { - // Schoolbook multiplication. 
+const fn q3_times_n_keep_nine(q3: &[Limb; 9]) -> [Limb; 9] { + // Schoolbook multiplication + + let modulus = MODULUS.as_limbs(); + + /* NOTE + * modulus[7] = 2^32 - 1 + * modulus[6] = 0 + * modulus[5] = 2^32 - 1 + * modulus[4] = 2^32 - 1 + */ + + let (w0, carry) = Limb::ZERO.mac(q3[0], modulus[0], Limb::ZERO); + let (w1, carry) = Limb::ZERO.mac(q3[0], modulus[1], carry); + let (w2, carry) = Limb::ZERO.mac(q3[0], modulus[2], carry); + let (w3, carry) = Limb::ZERO.mac(q3[0], modulus[3], carry); + let (w4, carry) = Limb::ZERO.mac(q3[0], modulus[4], carry); + let (w5, carry) = Limb::ZERO.mac(q3[0], modulus[5], carry); + // NOTE modulus[6] = 0 + // let (w6, carry) = Limb::ZERO.mac(q3[0], modulus[6], carry); + let (w6, carry) = (carry, Limb::ZERO); + let (w7, carry) = Limb::ZERO.mac(q3[0], modulus[7], carry); + // let (w8, _) = Limb::ZERO.mac(q3[0], Limb::ZERO, carry); + let (w8, _) = (carry, Limb::ZERO); + + let (w1, carry) = w1.mac(q3[1], modulus[0], Limb::ZERO); + let (w2, carry) = w2.mac(q3[1], modulus[1], carry); + let (w3, carry) = w3.mac(q3[1], modulus[2], carry); + let (w4, carry) = w4.mac(q3[1], modulus[3], carry); + let (w5, carry) = w5.mac(q3[1], modulus[4], carry); + let (w6, carry) = w6.mac(q3[1], modulus[5], carry); + // let (w7, carry) = w7.mac(q3[1], modulus[6], carry); + let (w7, carry) = w7.adc(Limb::ZERO, carry); + let (w8, _) = w8.mac(q3[1], modulus[7], carry); - let modulus = u256_to_u64x4(MODULUS); + let (w2, carry) = w2.mac(q3[2], modulus[0], Limb::ZERO); + let (w3, carry) = w3.mac(q3[2], modulus[1], carry); + let (w4, carry) = w4.mac(q3[2], modulus[2], carry); + let (w5, carry) = w5.mac(q3[2], modulus[3], carry); + let (w6, carry) = w6.mac(q3[2], modulus[4], carry); + let (w7, carry) = w7.mac(q3[2], modulus[5], carry); + // let (w8, _) = w8.mac(q3[2], modulus[6], carry); + let (w8, _) = w8.adc(Limb::ZERO, carry); - let (w0, carry) = mac(0, q3[0], modulus[0], 0); - let (w1, carry) = mac(0, q3[0], modulus[1], carry); - let (w2, carry) = mac(0, 
q3[0], modulus[2], carry); - let (w3, carry) = mac(0, q3[0], modulus[3], carry); - let (w4, _) = mac(0, q3[0], 0, carry); + let (w3, carry) = w3.mac(q3[3], modulus[0], Limb::ZERO); + let (w4, carry) = w4.mac(q3[3], modulus[1], carry); + let (w5, carry) = w5.mac(q3[3], modulus[2], carry); + let (w6, carry) = w6.mac(q3[3], modulus[3], carry); + let (w7, carry) = w7.mac(q3[3], modulus[4], carry); + let (w8, _) = w8.mac(q3[3], modulus[5], carry); - let (w1, carry) = mac(w1, q3[1], modulus[0], 0); - let (w2, carry) = mac(w2, q3[1], modulus[1], carry); - let (w3, carry) = mac(w3, q3[1], modulus[2], carry); - let (w4, _) = mac(w4, q3[1], modulus[3], carry); + let (w4, carry) = w4.mac(q3[4], modulus[0], Limb::ZERO); + let (w5, carry) = w5.mac(q3[4], modulus[1], carry); + let (w6, carry) = w6.mac(q3[4], modulus[2], carry); + let (w7, carry) = w7.mac(q3[4], modulus[3], carry); + let (w8, _) = w8.mac(q3[4], modulus[4], carry); - let (w2, carry) = mac(w2, q3[2], modulus[0], 0); - let (w3, carry) = mac(w3, q3[2], modulus[1], carry); - let (w4, _) = mac(w4, q3[2], modulus[2], carry); + let (w5, carry) = w5.mac(q3[5], modulus[0], Limb::ZERO); + let (w6, carry) = w6.mac(q3[5], modulus[1], carry); + let (w7, carry) = w7.mac(q3[5], modulus[2], carry); + let (w8, _) = w8.mac(q3[5], modulus[3], carry); - let (w3, carry) = mac(w3, q3[3], modulus[0], 0); - let (w4, _) = mac(w4, q3[3], modulus[1], carry); + let (w6, carry) = w6.mac(q3[6], modulus[0], Limb::ZERO); + let (w7, carry) = w7.mac(q3[6], modulus[1], carry); + let (w8, _) = w8.mac(q3[6], modulus[2], carry); - let (w4, _) = mac(w4, q3[4], modulus[0], 0); + let (w7, carry) = w7.mac(q3[7], modulus[0], Limb::ZERO); + let (w8, _) = w8.mac(q3[7], modulus[1], carry); - [w0, w1, w2, w3, w4] + let (w8, _) = w8.mac(q3[8], modulus[0], Limb::ZERO); + + [w0, w1, w2, w3, w4, w5, w6, w7, w8] } #[inline] #[allow(clippy::too_many_arguments)] -const fn sub_inner_five(l: [u64; 5], r: [u64; 5]) -> [u64; 5] { - let (w0, borrow) = sbb(l[0], r[0], 0); 
- let (w1, borrow) = sbb(l[1], r[1], borrow); - let (w2, borrow) = sbb(l[2], r[2], borrow); - let (w3, borrow) = sbb(l[3], r[3], borrow); - let (w4, _borrow) = sbb(l[4], r[4], borrow); - - // If underflow occurred on the final limb - don't care (= add b^{k+1}). - [w0, w1, w2, w3, w4] +const fn sub_inner_nine(l: [Limb; 9], r: [Limb; 9]) -> [Limb; 9] { + let (w0, borrow) = l[0].sbb(r[0], Limb::ZERO); + let (w1, borrow) = l[1].sbb(r[1], borrow); + let (w2, borrow) = l[2].sbb(r[2], borrow); + let (w3, borrow) = l[3].sbb(r[3], borrow); + let (w4, borrow) = l[4].sbb(r[4], borrow); + let (w5, borrow) = l[5].sbb(r[5], borrow); + let (w6, borrow) = l[6].sbb(r[6], borrow); + let (w7, borrow) = l[7].sbb(r[7], borrow); + let (w8, _borrow) = l[8].sbb(r[8], borrow); + + // If underflow occurred in the final limb - don't care (= add b^{k+1}). + [w0, w1, w2, w3, w4, w5, w6, w7, w8] } #[inline] #[allow(clippy::too_many_arguments)] -const fn subtract_n_if_necessary(r0: u64, r1: u64, r2: u64, r3: u64, r4: u64) -> [u64; 5] { - let modulus = u256_to_u64x4(MODULUS); - - let (w0, borrow) = sbb(r0, modulus[0], 0); - let (w1, borrow) = sbb(r1, modulus[1], borrow); - let (w2, borrow) = sbb(r2, modulus[2], borrow); - let (w3, borrow) = sbb(r3, modulus[3], borrow); - let (w4, borrow) = sbb(r4, 0, borrow); - - // If underflow occurred on the final limb, borrow = 0xfff...fff, otherwise - // borrow = 0x000...000. Thus, we use it as a mask to conditionally add the - // modulus.
- let (w0, carry) = adc(w0, modulus[0] & borrow, 0); - let (w1, carry) = adc(w1, modulus[1] & borrow, carry); - let (w2, carry) = adc(w2, modulus[2] & borrow, carry); - let (w3, carry) = adc(w3, modulus[3] & borrow, carry); - let (w4, _carry) = adc(w4, 0, carry); - - [w0, w1, w2, w3, w4] -} +const fn subtract_n_if_necessary(r: [Limb; 9]) -> [Limb; 9] { + let modulus = MODULUS.as_limbs(); -// TODO(tarcieri): replace this with proper 32-bit arithmetic -#[inline] -const fn u256_to_u64x4(u256: U256) -> [u64; 4] { - let words = u256.as_words(); - - [ - (words[0] as u64) | ((words[1] as u64) << 32), - (words[2] as u64) | ((words[3] as u64) << 32), - (words[4] as u64) | ((words[5] as u64) << 32), - (words[6] as u64) | ((words[7] as u64) << 32), - ] + let (w0, borrow) = r[0].sbb(modulus[0], Limb::ZERO); + let (w1, borrow) = r[1].sbb(modulus[1], borrow); + let (w2, borrow) = r[2].sbb(modulus[2], borrow); + let (w3, borrow) = r[3].sbb(modulus[3], borrow); + let (w4, borrow) = r[4].sbb(modulus[4], borrow); + let (w5, borrow) = r[5].sbb(modulus[5], borrow); + let (w6, borrow) = r[6].sbb(modulus[6], borrow); + let (w7, borrow) = r[7].sbb(modulus[7], borrow); + let (w8, borrow) = r[8].sbb(Limb::ZERO, borrow); + + // If underflow occurred in the final limb, borrow = 0xfff...fff, otherwise + // borrow = 0x000...000. Thus, we use it as a mask to conditionally add + // the modulus. 
+ let (w0, carry) = w0.adc(modulus[0].bitand(borrow), Limb::ZERO); + let (w1, carry) = w1.adc(modulus[1].bitand(borrow), carry); + let (w2, carry) = w2.adc(modulus[2].bitand(borrow), carry); + let (w3, carry) = w3.adc(modulus[3].bitand(borrow), carry); + let (w4, carry) = w4.adc(modulus[4].bitand(borrow), carry); + let (w5, carry) = w5.adc(modulus[5].bitand(borrow), carry); + let (w6, carry) = w6.adc(modulus[6].bitand(borrow), carry); + let (w7, carry) = w7.adc(modulus[7].bitand(borrow), carry); + let (w8, _carry) = w8.adc(Limb::ZERO, carry); + + [w0, w1, w2, w3, w4, w5, w6, w7, w8] } diff --git a/p256/src/arithmetic/scalar/scalar64.rs b/p256/src/arithmetic/scalar/scalar64.rs index e15711bc..2bad5612 100644 --- a/p256/src/arithmetic/scalar/scalar64.rs +++ b/p256/src/arithmetic/scalar/scalar64.rs @@ -1,10 +1,18 @@ //! 64-bit secp256r1 scalar field algorithms. -use super::{MODULUS, MU}; -use crate::{ - arithmetic::util::{adc, mac, sbb}, - U256, -}; +use super::MODULUS; +use elliptic_curve::bigint::{Limb, U256}; + +/// MU = floor(2^512 / n) +/// = 115792089264276142090721624801893421302707618245269942344307673200490803338238 +/// = 0x100000000fffffffffffffffeffffffff43190552df1a6c21012ffd85eedf9bfe +const MU: [Limb; 5] = [ + Limb::from_u64(0x012f_fd85_eedf_9bfe), + Limb::from_u64(0x4319_0552_df1a_6c21), + Limb::from_u64(0xffff_fffe_ffff_ffff), + Limb::from_u64(0x0000_0000_ffff_ffff), + Limb::from_u64(0x0000_0000_0000_0001), +]; /// Barrett Reduction /// @@ -37,8 +45,8 @@ use crate::{ #[inline] #[allow(clippy::too_many_arguments)] pub(super) const fn barrett_reduce(lo: U256, hi: U256) -> U256 { - let lo = lo.as_words(); - let hi = hi.as_words(); + let lo = lo.as_limbs(); + let hi = hi.as_limbs(); let a0 = lo[0]; let a1 = lo[1]; let a2 = lo[2]; @@ -47,93 +55,100 @@ pub(super) const fn barrett_reduce(lo: U256, hi: U256) -> U256 { let a5 = hi[1]; let a6 = hi[2]; let a7 = hi[3]; - let q1: [u64; 5] = [a3, a4, a5, a6, a7]; + let q1 = [a3, a4, a5, a6, a7]; let q3 = 
q1_times_mu_shift_five(&q1); - let r1: [u64; 5] = [a0, a1, a2, a3, a4]; - let r2: [u64; 5] = q3_times_n_keep_five(&q3); - let r: [u64; 5] = sub_inner_five(r1, r2); + let r1 = [a0, a1, a2, a3, a4]; + let r2 = q3_times_n_keep_five(&q3); + let r = sub_inner_five(r1, r2); // Result is in range (0, 3*n - 1), // and 90% of the time, no subtraction will be needed. - let r = subtract_n_if_necessary(r[0], r[1], r[2], r[3], r[4]); - let r = subtract_n_if_necessary(r[0], r[1], r[2], r[3], r[4]); - U256::from_words([r[0], r[1], r[2], r[3]]) + let r = subtract_n_if_necessary(r); + let r = subtract_n_if_necessary(r); + U256::new([r[0], r[1], r[2], r[3]]) } -const fn q1_times_mu_shift_five(q1: &[u64; 5]) -> [u64; 5] { - // Schoolbook multiplication. - - let (_w0, carry) = mac(0, q1[0], MU[0], 0); - let (w1, carry) = mac(0, q1[0], MU[1], carry); - let (w2, carry) = mac(0, q1[0], MU[2], carry); - let (w3, carry) = mac(0, q1[0], MU[3], carry); - let (w4, w5) = mac(0, q1[0], MU[4], carry); - - let (_w1, carry) = mac(w1, q1[1], MU[0], 0); - let (w2, carry) = mac(w2, q1[1], MU[1], carry); - let (w3, carry) = mac(w3, q1[1], MU[2], carry); - let (w4, carry) = mac(w4, q1[1], MU[3], carry); - let (w5, w6) = mac(w5, q1[1], MU[4], carry); - - let (_w2, carry) = mac(w2, q1[2], MU[0], 0); - let (w3, carry) = mac(w3, q1[2], MU[1], carry); - let (w4, carry) = mac(w4, q1[2], MU[2], carry); - let (w5, carry) = mac(w5, q1[2], MU[3], carry); - let (w6, w7) = mac(w6, q1[2], MU[4], carry); - - let (_w3, carry) = mac(w3, q1[3], MU[0], 0); - let (w4, carry) = mac(w4, q1[3], MU[1], carry); - let (w5, carry) = mac(w5, q1[3], MU[2], carry); - let (w6, carry) = mac(w6, q1[3], MU[3], carry); - let (w7, w8) = mac(w7, q1[3], MU[4], carry); - - let (_w4, carry) = mac(w4, q1[4], MU[0], 0); - let (w5, carry) = mac(w5, q1[4], MU[1], carry); - let (w6, carry) = mac(w6, q1[4], MU[2], carry); - let (w7, carry) = mac(w7, q1[4], MU[3], carry); - let (w8, w9) = mac(w8, q1[4], MU[4], carry); +const fn 
q1_times_mu_shift_five(q1: &[Limb; 5]) -> [Limb; 5] { + // Schoolbook multiplication + + let (_w0, carry) = Limb::ZERO.mac(q1[0], MU[0], Limb::ZERO); + let (w1, carry) = Limb::ZERO.mac(q1[0], MU[1], carry); + let (w2, carry) = Limb::ZERO.mac(q1[0], MU[2], carry); + let (w3, carry) = Limb::ZERO.mac(q1[0], MU[3], carry); + // NOTE MU[4] == 1 + // let (w4, w5) = Limb::ZERO.mac(q1[0], MU[4], carry); + let (w4, w5) = Limb::ZERO.adc(q1[0], carry); + + let (_w1, carry) = w1.mac(q1[1], MU[0], Limb::ZERO); + let (w2, carry) = w2.mac(q1[1], MU[1], carry); + let (w3, carry) = w3.mac(q1[1], MU[2], carry); + let (w4, carry) = w4.mac(q1[1], MU[3], carry); + // let (w5, w6) = mac(w5, q1[1], MU[4], carry); + let (w5, w6) = w5.adc(q1[1], carry); + + let (_w2, carry) = w2.mac(q1[2], MU[0], Limb::ZERO); + let (w3, carry) = w3.mac(q1[2], MU[1], carry); + let (w4, carry) = w4.mac(q1[2], MU[2], carry); + let (w5, carry) = w5.mac(q1[2], MU[3], carry); + // let (w6, w7) = w6.mac(q1[2], MU[4], carry); + let (w6, w7) = w6.adc(q1[2], carry); + + let (_w3, carry) = w3.mac(q1[3], MU[0], Limb::ZERO); + let (w4, carry) = w4.mac(q1[3], MU[1], carry); + let (w5, carry) = w5.mac(q1[3], MU[2], carry); + let (w6, carry) = w6.mac(q1[3], MU[3], carry); + // let (w7, w8) = w7.mac(q1[3], MU[4], carry); + let (w7, w8) = w7.adc(q1[3], carry); + + let (_w4, carry) = w4.mac(q1[4], MU[0], Limb::ZERO); + let (w5, carry) = w5.mac(q1[4], MU[1], carry); + let (w6, carry) = w6.mac(q1[4], MU[2], carry); + let (w7, carry) = w7.mac(q1[4], MU[3], carry); + // let (w8, w9) = w8.mac(q1[4], MU[4], carry); + let (w8, w9) = w8.adc(q1[4], carry); // let q2 = [_w0, _w1, _w2, _w3, _w4, w5, w6, w7, w8, w9]; [w5, w6, w7, w8, w9] } -const fn q3_times_n_keep_five(q3: &[u64; 5]) -> [u64; 5] { +const fn q3_times_n_keep_five(q3: &[Limb; 5]) -> [Limb; 5] { // Schoolbook multiplication. 
- let modulus = MODULUS.as_words(); + let modulus = MODULUS.as_limbs(); - let (w0, carry) = mac(0, q3[0], modulus[0], 0); - let (w1, carry) = mac(0, q3[0], modulus[1], carry); - let (w2, carry) = mac(0, q3[0], modulus[2], carry); - let (w3, carry) = mac(0, q3[0], modulus[3], carry); - let (w4, _) = mac(0, q3[0], 0, carry); + let (w0, carry) = Limb::ZERO.mac(q3[0], modulus[0], Limb::ZERO); + let (w1, carry) = Limb::ZERO.mac(q3[0], modulus[1], carry); + let (w2, carry) = Limb::ZERO.mac(q3[0], modulus[2], carry); + let (w3, carry) = Limb::ZERO.mac(q3[0], modulus[3], carry); + // let (w4, _) = Limb::ZERO.mac(q3[0], 0, carry); + let (w4, _) = (carry, Limb::ZERO); - let (w1, carry) = mac(w1, q3[1], modulus[0], 0); - let (w2, carry) = mac(w2, q3[1], modulus[1], carry); - let (w3, carry) = mac(w3, q3[1], modulus[2], carry); - let (w4, _) = mac(w4, q3[1], modulus[3], carry); + let (w1, carry) = w1.mac(q3[1], modulus[0], Limb::ZERO); + let (w2, carry) = w2.mac(q3[1], modulus[1], carry); + let (w3, carry) = w3.mac(q3[1], modulus[2], carry); + let (w4, _) = w4.mac(q3[1], modulus[3], carry); - let (w2, carry) = mac(w2, q3[2], modulus[0], 0); - let (w3, carry) = mac(w3, q3[2], modulus[1], carry); - let (w4, _) = mac(w4, q3[2], modulus[2], carry); + let (w2, carry) = w2.mac(q3[2], modulus[0], Limb::ZERO); + let (w3, carry) = w3.mac(q3[2], modulus[1], carry); + let (w4, _) = w4.mac(q3[2], modulus[2], carry); - let (w3, carry) = mac(w3, q3[3], modulus[0], 0); - let (w4, _) = mac(w4, q3[3], modulus[1], carry); + let (w3, carry) = w3.mac(q3[3], modulus[0], Limb::ZERO); + let (w4, _) = w4.mac(q3[3], modulus[1], carry); - let (w4, _) = mac(w4, q3[4], modulus[0], 0); + let (w4, _) = w4.mac(q3[4], modulus[0], Limb::ZERO); [w0, w1, w2, w3, w4] } #[inline] #[allow(clippy::too_many_arguments)] -const fn sub_inner_five(l: [u64; 5], r: [u64; 5]) -> [u64; 5] { - let (w0, borrow) = sbb(l[0], r[0], 0); - let (w1, borrow) = sbb(l[1], r[1], borrow); - let (w2, borrow) = sbb(l[2], r[2], borrow); - 
let (w3, borrow) = sbb(l[3], r[3], borrow); - let (w4, _borrow) = sbb(l[4], r[4], borrow); +const fn sub_inner_five(l: [Limb; 5], r: [Limb; 5]) -> [Limb; 5] { + let (w0, borrow) = l[0].sbb(r[0], Limb::ZERO); + let (w1, borrow) = l[1].sbb(r[1], borrow); + let (w2, borrow) = l[2].sbb(r[2], borrow); + let (w3, borrow) = l[3].sbb(r[3], borrow); + let (w4, _borrow) = l[4].sbb(r[4], borrow); // If underflow occurred on the final limb - don't care (= add b^{k+1}). [w0, w1, w2, w3, w4] @@ -141,23 +156,23 @@ const fn sub_inner_five(l: [u64; 5], r: [u64; 5]) -> [u64; 5] { #[inline] #[allow(clippy::too_many_arguments)] -const fn subtract_n_if_necessary(r0: u64, r1: u64, r2: u64, r3: u64, r4: u64) -> [u64; 5] { - let modulus = MODULUS.as_words(); +const fn subtract_n_if_necessary(r: [Limb; 5]) -> [Limb; 5] { + let modulus = MODULUS.as_limbs(); - let (w0, borrow) = sbb(r0, modulus[0], 0); - let (w1, borrow) = sbb(r1, modulus[1], borrow); - let (w2, borrow) = sbb(r2, modulus[2], borrow); - let (w3, borrow) = sbb(r3, modulus[3], borrow); - let (w4, borrow) = sbb(r4, 0, borrow); + let (w0, borrow) = r[0].sbb(modulus[0], Limb::ZERO); + let (w1, borrow) = r[1].sbb(modulus[1], borrow); + let (w2, borrow) = r[2].sbb(modulus[2], borrow); + let (w3, borrow) = r[3].sbb(modulus[3], borrow); + let (w4, borrow) = r[4].sbb(Limb::ZERO, borrow); // If underflow occurred on the final limb, borrow = 0xfff...fff, otherwise // borrow = 0x000...000. Thus, we use it as a mask to conditionally add the // modulus. 
- let (w0, carry) = adc(w0, modulus[0] & borrow, 0); - let (w1, carry) = adc(w1, modulus[1] & borrow, carry); - let (w2, carry) = adc(w2, modulus[2] & borrow, carry); - let (w3, carry) = adc(w3, modulus[3] & borrow, carry); - let (w4, _carry) = adc(w4, 0, carry); + let (w0, carry) = w0.adc(modulus[0].bitand(borrow), Limb::ZERO); + let (w1, carry) = w1.adc(modulus[1].bitand(borrow), carry); + let (w2, carry) = w2.adc(modulus[2].bitand(borrow), carry); + let (w3, carry) = w3.adc(modulus[3].bitand(borrow), carry); + let (w4, _carry) = w4.adc(Limb::ZERO, carry); [w0, w1, w2, w3, w4] } diff --git a/p256/src/arithmetic/util.rs b/p256/src/arithmetic/util.rs deleted file mode 100644 index 8ce5a9db..00000000 --- a/p256/src/arithmetic/util.rs +++ /dev/null @@ -1,72 +0,0 @@ -//! Helper functions. -// TODO(tarcieri): replace these with `crypto-bigint` - -use elliptic_curve::bigint::U256; - -/// Computes `a + b + carry`, returning the result along with the new carry. 64-bit version. -#[inline(always)] -pub(crate) const fn adc(a: u64, b: u64, carry: u64) -> (u64, u64) { - let ret = (a as u128) + (b as u128) + (carry as u128); - (ret as u64, (ret >> 64) as u64) -} - -/// Computes `a - (b + borrow)`, returning the result along with the new borrow. 64-bit version. -#[inline(always)] -pub(crate) const fn sbb(a: u64, b: u64, borrow: u64) -> (u64, u64) { - let ret = (a as u128).wrapping_sub((b as u128) + ((borrow >> 63) as u128)); - (ret as u64, (ret >> 64) as u64) -} - -/// Computes `a + (b * c) + carry`, returning the result along with the new carry. -#[inline(always)] -pub(crate) const fn mac(a: u64, b: u64, c: u64, carry: u64) -> (u64, u64) { - let ret = (a as u128) + ((b as u128) * (c as u128)) + (carry as u128); - (ret as u64, (ret >> 64) as u64) -} - -/// Array containing 4 x 64-bit unsigned integers. -// TODO(tarcieri): replace this entirely with `U256` -pub(crate) type U64x4 = [u64; 4]; - -/// Convert to a [`U64x4`] array. 
-// TODO(tarcieri): implement all algorithms in terms of `U256`? -#[cfg(target_pointer_width = "32")] -pub(crate) const fn u256_to_u64x4(u256: U256) -> U64x4 { - let limbs = u256.to_words(); - - [ - (limbs[0] as u64) | ((limbs[1] as u64) << 32), - (limbs[2] as u64) | ((limbs[3] as u64) << 32), - (limbs[4] as u64) | ((limbs[5] as u64) << 32), - (limbs[6] as u64) | ((limbs[7] as u64) << 32), - ] -} - -/// Convert to a [`U64x4`] array. -// TODO(tarcieri): implement all algorithms in terms of `U256`? -#[cfg(target_pointer_width = "64")] -pub(crate) const fn u256_to_u64x4(u256: U256) -> U64x4 { - u256.to_words() -} - -/// Convert from a [`U64x4`] array. -#[cfg(target_pointer_width = "32")] -pub(crate) const fn u64x4_to_u256(limbs: U64x4) -> U256 { - U256::from_words([ - (limbs[0] & 0xFFFFFFFF) as u32, - (limbs[0] >> 32) as u32, - (limbs[1] & 0xFFFFFFFF) as u32, - (limbs[1] >> 32) as u32, - (limbs[2] & 0xFFFFFFFF) as u32, - (limbs[2] >> 32) as u32, - (limbs[3] & 0xFFFFFFFF) as u32, - (limbs[3] >> 32) as u32, - ]) -} - -/// Convert from a [`U64x4`] array. -// TODO(tarcieri): implement all algorithms in terms of `U256`? 
-#[cfg(target_pointer_width = "64")] -pub(crate) const fn u64x4_to_u256(limbs: U64x4) -> U256 { - U256::from_words(limbs) -} From 30cb61dfbe198f6985c0215beac51be86fa38e9c Mon Sep 17 00:00:00 2001 From: Thia Su Mian Date: Fri, 12 Jul 2024 14:38:23 +0800 Subject: [PATCH 02/10] add projective and affine points for p256 --- p256/Cargo.toml | 11 +- p256/src/arithmetic.rs | 65 +-- p256/src/arithmetic/affine.rs | 371 ++++++++++++++ p256/src/arithmetic/projective.rs | 776 ++++++++++++++++++++++++++++++ p256/src/lib.rs | 5 + 5 files changed, 1182 insertions(+), 46 deletions(-) create mode 100644 p256/src/arithmetic/affine.rs create mode 100644 p256/src/arithmetic/projective.rs diff --git a/p256/Cargo.toml b/p256/Cargo.toml index 97f394c0..8f2530a7 100644 --- a/p256/Cargo.toml +++ b/p256/Cargo.toml @@ -21,6 +21,7 @@ cfg-if = "1.0" elliptic-curve = { version = "0.13.8", default-features = false, features = ["hazmat", "sec1"] } # optional dependencies +once_cell = { version = "1.19", optional = true, default-features = false } ecdsa-core = { version = "0.16", package = "ecdsa", optional = true, default-features = false, features = ["der"] } hex-literal = { version = "0.4", optional = true } primeorder = { version = "0.13", optional = true, path = "../primeorder" } @@ -36,10 +37,18 @@ primeorder = { version = "0.13.5", features = ["dev"], path = "../primeorder" } proptest = "1" rand_core = { version = "0.6", features = ["getrandom"] } +[target.'cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))'.dev-dependencies] +criterion = { version = "0.5", features = ["html_reports"] } +proptest = "1.4" + +[target.'cfg(all(target_os = "zkvm", target_arch = "riscv32"))'.dev-dependencies] +proptest = { version = "1.4", default-features = false, features = ["alloc"] } +hex = "0.4" + [features] default = ["arithmetic", "ecdsa", "pem", "std"] alloc = ["ecdsa-core?/alloc", "elliptic-curve/alloc", "primeorder?/alloc"] -std = ["alloc", "ecdsa-core?/std", "elliptic-curve/std"] +std = 
["alloc", "ecdsa-core?/std", "elliptic-curve/std", "once_cell?/std"] arithmetic = ["dep:primeorder", "elliptic-curve/arithmetic"] bits = ["arithmetic", "elliptic-curve/bits"] diff --git a/p256/src/arithmetic.rs b/p256/src/arithmetic.rs index 4120e93d..7d4aef84 100644 --- a/p256/src/arithmetic.rs +++ b/p256/src/arithmetic.rs @@ -4,21 +4,22 @@ //! //! [NIST SP 800-186]: https://csrc.nist.gov/publications/detail/sp/800-186/final +pub(crate) mod affine; pub(crate) mod field; #[cfg(feature = "hash2curve")] mod hash2curve; +pub(crate) mod projective; pub(crate) mod scalar; use self::{field::FieldElement, scalar::Scalar}; use crate::NistP256; -use elliptic_curve::{bigint::U256, CurveArithmetic, PrimeCurveArithmetic}; -use primeorder::{point_arithmetic, PrimeCurveParams}; +use elliptic_curve::{bigint::U256, CurveArithmetic}; /// Elliptic curve point in affine coordinates. -pub type AffinePoint = primeorder::AffinePoint; +pub type AffinePoint = affine::AffinePoint; /// Elliptic curve point in projective coordinates. -pub type ProjectivePoint = primeorder::ProjectivePoint; +pub type ProjectivePoint = projective::ProjectivePoint; impl CurveArithmetic for NistP256 { type AffinePoint = AffinePoint; @@ -26,54 +27,27 @@ impl CurveArithmetic for NistP256 { type Scalar = Scalar; } -impl PrimeCurveArithmetic for NistP256 { - type CurveGroup = ProjectivePoint; -} - -/// Adapted from [NIST SP 800-186] § G.1.2: Curve P-256. 
-/// -/// [NIST SP 800-186]: https://csrc.nist.gov/publications/detail/sp/800-186/final -impl PrimeCurveParams for NistP256 { - type FieldElement = FieldElement; - type PointArithmetic = point_arithmetic::EquationAIsMinusThree; - - /// a = -3 - const EQUATION_A: FieldElement = FieldElement(U256::from_be_hex( - "FFFFFFFC00000004000000000000000000000003FFFFFFFFFFFFFFFFFFFFFFFC", - )); +/// a = -3 +const CURVE_EQUATION_A: FieldElement = FieldElement(U256::from_be_hex( + "FFFFFFFC00000004000000000000000000000003FFFFFFFFFFFFFFFFFFFFFFFC", +)); - const EQUATION_B: FieldElement = FieldElement(U256::from_be_hex( - "DC30061D04874834E5A220ABF7212ED6ACF005CD78843090D89CDF6229C4BDDF", - )); - - /// Base point of P-256. - /// - /// Defined in NIST SP 800-186 § G.1.2: - /// - /// ```text - /// Gₓ = 6b17d1f2 e12c4247 f8bce6e5 63a440f2 77037d81 2deb33a0 f4a13945 d898c296 - /// Gᵧ = 4fe342e2 fe1a7f9b 8ee7eb4a 7c0f9e16 2bce3357 6b315ece cbb64068 37bf51f5 - /// ``` - const GENERATOR: (FieldElement, FieldElement) = ( - FieldElement(U256::from_be_hex( - "18905F76A53755C679FB732B7762251075BA95FC5FEDB60179E730D418A9143C", - )), - FieldElement(U256::from_be_hex( - "8571FF1825885D85D2E88688DD21F3258B4AB8E4BA19E45CDDF25357CE95560A", - )), - ); -} +const CURVE_EQUATION_B: FieldElement = FieldElement(U256::from_be_hex( + "DC30061D04874834E5A220ABF7212ED6ACF005CD78843090D89CDF6229C4BDDF", +)); #[cfg(test)] mod tests { use super::FieldElement; - use crate::NistP256; - use primeorder::PrimeCurveParams; + use crate::{ + arithmetic::{CURVE_EQUATION_A, CURVE_EQUATION_B}, + AffinePoint, + }; #[test] fn equation_a_constant() { let equation_a = FieldElement::from_u64(3).neg(); - assert_eq!(equation_a, NistP256::EQUATION_A); + assert_eq!(equation_a, CURVE_EQUATION_A); } #[test] @@ -81,7 +55,7 @@ mod tests { let equation_b = FieldElement::from_hex( "5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b", ); - assert_eq!(equation_b, NistP256::EQUATION_B); + assert_eq!(equation_b, 
CURVE_EQUATION_B); } #[test] @@ -94,6 +68,7 @@ mod tests { "4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5", ), ); - assert_eq!(generator, NistP256::GENERATOR); + assert_eq!(generator.0, AffinePoint::GENERATOR.x); + assert_eq!(generator.1, AffinePoint::GENERATOR.y); } } diff --git a/p256/src/arithmetic/affine.rs b/p256/src/arithmetic/affine.rs new file mode 100644 index 00000000..40b104fc --- /dev/null +++ b/p256/src/arithmetic/affine.rs @@ -0,0 +1,371 @@ +//! Affine curve points. + +#![allow(clippy::op_ref)] + +use super::{FieldElement, ProjectivePoint, CURVE_EQUATION_A, CURVE_EQUATION_B}; +use crate::{CompressedPoint, EncodedPoint, FieldBytes, NistP256, PublicKey, Scalar}; +use core::ops::{Mul, Neg}; +use elliptic_curve::{ + bigint::U256, + ff::PrimeField, + group::{prime::PrimeCurveAffine, GroupEncoding}, + point::{AffineCoordinates, DecompactPoint, DecompressPoint}, + sec1::{self, FromEncodedPoint, ToCompactEncodedPoint, ToEncodedPoint}, + subtle::{Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, CtOption}, + zeroize::DefaultIsZeroes, + Error, FieldBytesEncoding, Result, +}; + +#[cfg(feature = "serde")] +use serdect::serde::{de, ser, Deserialize, Serialize}; + +/// Point on a Weierstrass curve in affine coordinates. +#[derive(Clone, Copy, Debug)] +pub struct AffinePoint { + /// x-coordinate + pub(crate) x: FieldElement, + + /// y-coordinate + pub(crate) y: FieldElement, + + /// Is this point the point at infinity? 0 = no, 1 = yes + /// + /// This is a proxy for [`Choice`], but uses `u8` instead to permit `const` + /// constructors for `IDENTITY` and `GENERATOR`. + pub(crate) infinity: u8, +} + +impl AffinePoint { + /// Additive identity of the group a.k.a. the point at infinity. + pub const IDENTITY: Self = Self { + x: FieldElement::ZERO, + y: FieldElement::ZERO, + infinity: 1, + }; + + /// Base point of the curve. 
+ pub const GENERATOR: Self = Self { + x: FieldElement(U256::from_be_hex( + "18905F76A53755C679FB732B7762251075BA95FC5FEDB60179E730D418A9143C", + )), + y: FieldElement(U256::from_be_hex( + "8571FF1825885D85D2E88688DD21F3258B4AB8E4BA19E45CDDF25357CE95560A", + )), + infinity: 0, + }; + + /// Is this point the point at infinity? + pub fn is_identity(&self) -> Choice { + Choice::from(self.infinity) + } + + /// Conditionally negate [`AffinePoint`] for use with point compaction. + fn to_compact(self) -> Self { + let neg_self = -self; + let choice = >::decode_field_bytes(&self.y.to_repr()) + .ct_gt(&>::decode_field_bytes( + &neg_self.y.to_repr(), + )); + + Self { + x: self.x, + y: FieldElement::conditional_select(&self.y, &neg_self.y, choice), + infinity: self.infinity, + } + } +} + +impl AffineCoordinates for AffinePoint { + type FieldRepr = FieldBytes; + + fn x(&self) -> FieldBytes { + self.x.to_repr() + } + + fn y_is_odd(&self) -> Choice { + self.y.is_odd() + } +} + +impl ConditionallySelectable for AffinePoint { + #[inline(always)] + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + Self { + x: FieldElement::conditional_select(&a.x, &b.x, choice), + y: FieldElement::conditional_select(&a.y, &b.y, choice), + infinity: u8::conditional_select(&a.infinity, &b.infinity, choice), + } + } +} + +impl ConstantTimeEq for AffinePoint { + fn ct_eq(&self, other: &Self) -> Choice { + self.x.ct_eq(&other.x) & self.y.ct_eq(&other.y) & self.infinity.ct_eq(&other.infinity) + } +} + +impl Default for AffinePoint { + fn default() -> Self { + Self::IDENTITY + } +} + +impl DefaultIsZeroes for AffinePoint {} + +impl DecompressPoint for AffinePoint { + fn decompress(x_bytes: &FieldBytes, y_is_odd: Choice) -> CtOption { + FieldElement::from_repr(*x_bytes).and_then(|x| { + let alpha = x * &x * &x + &(CURVE_EQUATION_A * &x) + &CURVE_EQUATION_B; + let beta = alpha.sqrt(); + + beta.map(|beta| { + let y = + FieldElement::conditional_select(&-beta, &beta, 
beta.is_odd().ct_eq(&y_is_odd)); + + Self { x, y, infinity: 0 } + }) + }) + } +} + +impl DecompactPoint for AffinePoint { + fn decompact(x_bytes: &FieldBytes) -> CtOption { + Self::decompress(x_bytes, Choice::from(0)).map(|point| point.to_compact()) + } +} + +impl FromEncodedPoint for AffinePoint { + /// Attempts to parse the given [`EncodedPoint`] as an SEC1-encoded + /// [`AffinePoint`]. + /// + /// # Returns + /// + /// `None` value if `encoded_point` is not on the secp256r1 curve. + fn from_encoded_point(encoded_point: &EncodedPoint) -> CtOption { + match encoded_point.coordinates() { + sec1::Coordinates::Identity => CtOption::new(Self::IDENTITY, 1.into()), + sec1::Coordinates::Compact { x } => Self::decompact(x), + sec1::Coordinates::Compressed { x, y_is_odd } => { + Self::decompress(x, Choice::from(y_is_odd as u8)) + } + sec1::Coordinates::Uncompressed { x, y } => FieldElement::from_repr(*y).and_then(|y| { + FieldElement::from_repr(*x).and_then(|x| { + let lhs = y * &y; + let rhs = x * &x * &x + &(CURVE_EQUATION_A * &x) + &CURVE_EQUATION_B; + CtOption::new(Self { x, y, infinity: 0 }, lhs.ct_eq(&rhs)) + }) + }), + } + } +} + +impl Eq for AffinePoint {} + +impl From for AffinePoint { + fn from(p: ProjectivePoint) -> AffinePoint { + p.to_affine() + } +} + +impl From<&ProjectivePoint> for AffinePoint { + fn from(p: &ProjectivePoint) -> AffinePoint { + p.to_affine() + } +} + +impl From for AffinePoint { + fn from(public_key: PublicKey) -> AffinePoint { + *public_key.as_affine() + } +} + +impl From<&PublicKey> for AffinePoint { + fn from(public_key: &PublicKey) -> AffinePoint { + AffinePoint::from(*public_key) + } +} + +impl From for EncodedPoint { + fn from(affine: AffinePoint) -> EncodedPoint { + affine.to_encoded_point(false) + } +} + +impl GroupEncoding for AffinePoint { + type Repr = CompressedPoint; + + /// NOTE: not constant-time with respect to identity point + fn from_bytes(bytes: &Self::Repr) -> CtOption { + EncodedPoint::from_bytes(bytes) + .map(|point| 
CtOption::new(point, Choice::from(1))) + .unwrap_or_else(|_| { + // SEC1 identity encoding is technically 1-byte 0x00, but the + // `GroupEncoding` API requires a fixed-width `Repr` + let is_identity = bytes.ct_eq(&Self::Repr::default()); + CtOption::new(EncodedPoint::identity(), is_identity) + }) + .and_then(|point| Self::from_encoded_point(&point)) + } + + fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption { + // No unchecked conversion possible for compressed points + Self::from_bytes(bytes) + } + + fn to_bytes(&self) -> Self::Repr { + let encoded = self.to_encoded_point(true); + let mut result = CompressedPoint::default(); + result[..encoded.len()].copy_from_slice(encoded.as_bytes()); + result + } +} + +impl PartialEq for AffinePoint { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl PrimeCurveAffine for AffinePoint { + type Curve = ProjectivePoint; + type Scalar = Scalar; + + fn identity() -> AffinePoint { + Self::IDENTITY + } + + fn generator() -> AffinePoint { + Self::GENERATOR + } + + fn is_identity(&self) -> Choice { + self.is_identity() + } + + fn to_curve(&self) -> ProjectivePoint { + ProjectivePoint::from(*self) + } +} + +impl ToCompactEncodedPoint for AffinePoint { + /// Serialize this value as a SEC1 compact [`EncodedPoint`] + fn to_compact_encoded_point(&self) -> CtOption { + let point = self.to_compact(); + + let mut bytes = CompressedPoint::default(); + bytes[0] = sec1::Tag::Compact.into(); + bytes[1..].copy_from_slice(&point.x.to_repr()); + + let encoded = EncodedPoint::from_bytes(bytes); + let is_some = point.y.ct_eq(&self.y); + CtOption::new(encoded.unwrap_or_default(), is_some) + } +} + +impl ToEncodedPoint for AffinePoint { + fn to_encoded_point(&self, compress: bool) -> EncodedPoint { + EncodedPoint::conditional_select( + &EncodedPoint::from_affine_coordinates(&self.x.to_repr(), &self.y.to_repr(), compress), + &EncodedPoint::identity(), + self.is_identity(), + ) + } +} + +impl TryFrom for AffinePoint { + 
type Error = Error; + + fn try_from(point: EncodedPoint) -> Result { + AffinePoint::try_from(&point) + } +} + +impl TryFrom<&EncodedPoint> for AffinePoint { + type Error = Error; + + fn try_from(point: &EncodedPoint) -> Result { + Option::from(AffinePoint::from_encoded_point(point)).ok_or(Error) + } +} + +impl TryFrom for PublicKey { + type Error = Error; + + fn try_from(affine_point: AffinePoint) -> Result { + PublicKey::from_affine(affine_point) + } +} + +impl TryFrom<&AffinePoint> for PublicKey { + type Error = Error; + + fn try_from(affine_point: &AffinePoint) -> Result { + PublicKey::try_from(*affine_point) + } +} + +// +// Arithmetic trait impls +// + +impl Mul for AffinePoint { + type Output = ProjectivePoint; + + fn mul(self, scalar: Scalar) -> ProjectivePoint { + ProjectivePoint::from(self) * scalar + } +} + +impl Mul<&Scalar> for AffinePoint { + type Output = ProjectivePoint; + + fn mul(self, scalar: &Scalar) -> ProjectivePoint { + ProjectivePoint::from(self) * *scalar + } +} + +impl Neg for AffinePoint { + type Output = Self; + + fn neg(self) -> Self { + AffinePoint { + x: self.x, + y: -self.y, + infinity: self.infinity, + } + } +} + +impl Neg for &AffinePoint { + type Output = AffinePoint; + + fn neg(self) -> AffinePoint { + -(*self) + } +} + +// +// serde support +// + +#[cfg(feature = "serde")] +impl Serialize for AffinePoint { + fn serialize(&self, serializer: S) -> core::result::Result + where + S: ser::Serializer, + { + self.to_encoded_point(true).serialize(serializer) + } +} + +#[cfg(feature = "serde")] +impl<'de> Deserialize<'de> for AffinePoint { + fn deserialize(deserializer: D) -> core::result::Result + where + D: de::Deserializer<'de>, + { + EncodedPoint::deserialize(deserializer)? + .try_into() + .map_err(de::Error::custom) + } +} diff --git a/p256/src/arithmetic/projective.rs b/p256/src/arithmetic/projective.rs new file mode 100644 index 00000000..86edb616 --- /dev/null +++ b/p256/src/arithmetic/projective.rs @@ -0,0 +1,776 @@ +//! 
Projective curve points. + +#![allow(clippy::needless_range_loop, clippy::op_ref)] + +use super::{AffinePoint, FieldElement}; +use crate::{ + arithmetic::{CURVE_EQUATION_A, CURVE_EQUATION_B}, + CompressedPoint, EncodedPoint, NistP256, PublicKey, Scalar, +}; +use core::{ + borrow::Borrow, + iter::Sum, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, +}; +use elliptic_curve::{ + bigint::{ArrayEncoding, U256}, + generic_array::ArrayLength, + group::{ + self, + cofactor::CofactorGroup, + ff::Field, + prime::{PrimeCurve, PrimeGroup}, + Group, GroupEncoding, + }, + ops::{BatchInvert, LinearCombination, MulByGenerator}, + point::Double, + rand_core::RngCore, + sec1::{FromEncodedPoint, ModulusSize, ToEncodedPoint, UncompressedPointSize}, + subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}, + zeroize::DefaultIsZeroes, + BatchNormalize, Error, FieldBytes, FieldBytesSize, Result, +}; + +#[cfg(feature = "alloc")] +use alloc::vec::Vec; + +/// Point on a Weierstrass curve in projective coordinates. +#[derive(Clone, Copy, Debug)] +pub struct ProjectivePoint { + pub(crate) x: FieldElement, + pub(crate) y: FieldElement, + pub(crate) z: FieldElement, +} + +impl ProjectivePoint { + /// Additive identity of the group a.k.a. the point at infinity. + pub const IDENTITY: Self = Self { + x: FieldElement::ZERO, + y: FieldElement::ONE, + z: FieldElement::ZERO, + }; + + /// Base point of the curve. + pub const GENERATOR: Self = Self { + x: AffinePoint::GENERATOR.x, + y: AffinePoint::GENERATOR.y, + z: FieldElement::ONE, + }; + + /// Returns the affine representation of this point, or `None` if it is the identity. + pub fn to_affine(&self) -> AffinePoint { + ::invert(&self.z) + .map(|zinv| self.to_affine_internal(zinv)) + .unwrap_or(AffinePoint::IDENTITY) + } + + pub(super) fn to_affine_internal(self, zinv: FieldElement) -> AffinePoint { + AffinePoint { + x: self.x * &zinv, + y: self.y * &zinv, + infinity: 0, + } + } + + /// Returns `-self`. 
+ pub fn neg(&self) -> Self { + Self { + x: self.x, + y: -self.y, + z: self.z, + } + } + + /// Returns `self + other`. + /// Implements complete addition for curves with `a = -3` + /// + /// Implements the complete addition formula from [Renes-Costello-Batina 2015] + /// (Algorithm 4). The comments after each line indicate which algorithm steps + /// are being performed. + /// + /// [Renes-Costello-Batina 2015]: https://eprint.iacr.org/2015/1060 + pub fn add(&self, other: &ProjectivePoint) -> ProjectivePoint { + debug_assert_eq!( + CURVE_EQUATION_A, + -FieldElement::from(3), + "this implementation is only valid for C::EQUATION_A = -3" + ); + + let xx = self.x * other.x; // 1 + let yy = self.y * other.y; // 2 + let zz = self.z * other.z; // 3 + let xy_pairs = ((self.x + self.y) * (other.x + other.y)) - (xx + yy); // 4, 5, 6, 7, 8 + let yz_pairs = ((self.y + self.z) * (other.y + other.z)) - (yy + zz); // 9, 10, 11, 12, 13 + let xz_pairs = ((self.x + self.z) * (other.x + other.z)) - (xx + zz); // 14, 15, 16, 17, 18 + + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + let bzz_part = xz_pairs - zz.mul(CURVE_EQUATION_B); // 19, 20 + let bzz3_part = bzz_part.mul_single(3); // 21, 22 + + let yy_m_bzz3 = yy - bzz3_part; // 23 + let yy_p_bzz3 = yy + bzz3_part; // 24 + + let zz3 = zz.mul_single(3); // 26, 27 + let bxz_part = xz_pairs.mul(CURVE_EQUATION_B) - (zz3 + xx); // 25, 28, 29 + let bxz3_part = bxz_part.mul_single(3); // 30, 31 + let xx3_m_zz3 = xx.mul_single(3) - zz3; // 32, 33, 34 + + return ProjectivePoint { + x: (yy_p_bzz3.mul(xy_pairs)) - (yz_pairs.mul(bxz3_part)), // 35, 39, 40 + y: (yy_p_bzz3.mul(yy_m_bzz3)) + (xx3_m_zz3.mul(bxz3_part)), // 36, 37, 38 + z: (yy_m_bzz3.mul(yz_pairs)) + (xy_pairs.mul(xx3_m_zz3)), // 41, 42, 43 + }; + } + + let bzz_part = xz_pairs - (CURVE_EQUATION_B * zz); // 19, 20 + let bzz3_part = bzz_part.double() + bzz_part; // 21, 22 + let yy_m_bzz3 = yy - bzz3_part; // 23 + let yy_p_bzz3 = yy + bzz3_part; // 24 + + let zz3 = 
zz.double() + zz; // 26, 27 + let bxz_part = (CURVE_EQUATION_B * xz_pairs) - (zz3 + xx); // 25, 28, 29 + let bxz3_part = bxz_part.double() + bxz_part; // 30, 31 + let xx3_m_zz3 = xx.double() + xx - zz3; // 32, 33, 34 + + ProjectivePoint { + x: (yy_p_bzz3 * xy_pairs) - (yz_pairs * bxz3_part), // 35, 39, 40 + y: (yy_p_bzz3 * yy_m_bzz3) + (xx3_m_zz3 * bxz3_part), // 36, 37, 38 + z: (yy_m_bzz3 * yz_pairs) + (xy_pairs * xx3_m_zz3), // 41, 42, 43 + } + } + + /// Returns `self + other`. + /// Implements complete mixed addition for curves with `a = -3` + /// + /// Implements the complete mixed addition formula from [Renes-Costello-Batina 2015] + /// (Algorithm 5). The comments after each line indicate which algorithm + /// steps are being performed. + /// + /// [Renes-Costello-Batina 2015]: https://eprint.iacr.org/2015/1060 + fn add_mixed(&self, other: &AffinePoint) -> ProjectivePoint { + debug_assert_eq!( + CURVE_EQUATION_A, + -FieldElement::from(3), + "this implementation is only valid for C::EQUATION_A = -3" + ); + + let xx = self.x * other.x; // 1 + let yy = self.y * other.y; // 2 + let xy_pairs = ((self.x + self.y) * (other.x + other.y)) - (xx + yy); // 3, 4, 5, 6, 7 + let yz_pairs = (other.y * self.z) + self.y; // 8, 9 (t4) + let xz_pairs = (other.x * self.z) + self.x; // 10, 11 (y3) + + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + let bz_part = xz_pairs - self.z.mul(CURVE_EQUATION_B); // 12, 13 + let bz3_part = bz_part.mul_single(3); // 14, 15 + let yy_m_bzz3 = yy - bz3_part; // 16 + let yy_p_bzz3 = yy + bz3_part; // 17 + + let z3 = self.z.mul_single(3); // 19, 20 + let bxz_part = xz_pairs.mul(CURVE_EQUATION_B) - (z3 + xx); // 18, 21, 22 + let bxz3_part = bxz_part.mul_single(3); // 23, 24 + let xx3_m_zz3 = xx.mul_single(3) - z3; // 25, 26, 27 + + let mut ret = ProjectivePoint { + x: (yy_p_bzz3.mul(xy_pairs)) - (yz_pairs.mul(bxz3_part)), // 28, 32, 33 + y: (yy_p_bzz3.mul(yy_m_bzz3)) + (xx3_m_zz3.mul(bxz3_part)), // 29, 30, 31 + z: 
(yy_m_bzz3.mul(yz_pairs)) + (xy_pairs.mul(xx3_m_zz3)), // 34, 35, 36 + }; + ret.conditional_assign(self, other.is_identity()); + return ret; + } + + let bz_part = xz_pairs - (CURVE_EQUATION_B * self.z); // 12, 13 + let bz3_part = bz_part.double() + bz_part; // 14, 15 + let yy_m_bzz3 = yy - bz3_part; // 16 + let yy_p_bzz3 = yy + bz3_part; // 17 + + let z3 = self.z.double() + self.z; // 19, 20 + let bxz_part = (CURVE_EQUATION_B * xz_pairs) - (z3 + xx); // 18, 21, 22 + let bxz3_part = bxz_part.double() + bxz_part; // 23, 24 + let xx3_m_zz3 = xx.double() + xx - z3; // 25, 26, 27 + + let mut ret = ProjectivePoint { + x: (yy_p_bzz3 * xy_pairs) - (yz_pairs * bxz3_part), // 28, 32, 33 + y: (yy_p_bzz3 * yy_m_bzz3) + (xx3_m_zz3 * bxz3_part), // 29, 30, 31 + z: (yy_m_bzz3 * yz_pairs) + (xy_pairs * xx3_m_zz3), // 34, 35, 36 + }; + ret.conditional_assign(self, other.is_identity()); + ret + } + + /// Returns `self - other`. + pub fn sub(&self, other: &Self) -> Self { + self.add(&other.neg()) + } + + /// Returns `self - other`. + fn sub_mixed(&self, other: &AffinePoint) -> Self { + self.add_mixed(&other.neg()) + } + + /// Returns `[k] self`. 
+ fn mul(&self, k: &Scalar) -> Self { + // Into::into(*k) -> Uint for NIST P256 is U256 + let k = Into::::into(*k).to_le_byte_array(); + + let mut pc = [Self::default(); 16]; + pc[0] = Self::IDENTITY; + pc[1] = *self; + + for i in 2..16 { + pc[i] = if i % 2 == 0 { + Double::double(&pc[i / 2]) + } else { + pc[i - 1].add(self) + }; + } + + let mut q = Self::IDENTITY; + let mut pos = U256::BITS - 4; + + loop { + let slot = (k[pos >> 3] >> (pos & 7)) & 0xf; + + let mut t = ProjectivePoint::IDENTITY; + + for i in 1..16 { + t.conditional_assign( + &pc[i], + Choice::from(((slot as usize ^ i).wrapping_sub(1) >> 8) as u8 & 1), + ); + } + + q = q.add(&t); + + if pos == 0 { + break; + } + + q = Double::double(&Double::double(&Double::double(&Double::double(&q)))); + pos -= 4; + } + + q + } +} + +impl CofactorGroup for ProjectivePoint { + type Subgroup = Self; + + fn clear_cofactor(&self) -> Self::Subgroup { + *self + } + + fn into_subgroup(self) -> CtOption { + CtOption::new(self, 1.into()) + } + + fn is_torsion_free(&self) -> Choice { + 1.into() + } +} + +impl ConditionallySelectable for ProjectivePoint { + #[inline(always)] + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + Self { + x: FieldElement::conditional_select(&a.x, &b.x, choice), + y: FieldElement::conditional_select(&a.y, &b.y, choice), + z: FieldElement::conditional_select(&a.z, &b.z, choice), + } + } +} + +impl ConstantTimeEq for ProjectivePoint { + fn ct_eq(&self, other: &Self) -> Choice { + self.to_affine().ct_eq(&other.to_affine()) + } +} + +impl Default for ProjectivePoint { + fn default() -> Self { + Self::IDENTITY + } +} + +impl DefaultIsZeroes for ProjectivePoint {} + +impl Double for ProjectivePoint { + /// Implements point doubling for curves with `a = -3` + /// + /// Implements the exception-free point doubling formula from [Renes-Costello-Batina 2015] + /// (Algorithm 6). The comments after each line indicate which algorithm + /// steps are being performed. 
+ /// + /// [Renes-Costello-Batina 2015]: https://eprint.iacr.org/2015/1060 + fn double(&self) -> ProjectivePoint { + debug_assert_eq!( + CURVE_EQUATION_A, + -FieldElement::from(3), + "this implementation is only valid for C::EQUATION_A = -3" + ); + + let xx = self.x.square(); // 1 + let yy = self.y.square(); // 2 + let zz = self.z.square(); // 3 + let xy2 = (self.x * self.y).double(); // 4, 5 + let xz2 = (self.x * self.z).double(); // 6, 7 + + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + let bzz_part = zz.mul(CURVE_EQUATION_B) - xz2; // 8, 9 + let bzz3_part = bzz_part.mul_single(3); // 10, 11 + let yy_m_bzz3 = yy - bzz3_part; // 12 + let yy_p_bzz3 = yy + bzz3_part; // 13 + let y_frag = yy_p_bzz3.mul(yy_m_bzz3); // 14 + let x_frag = yy_m_bzz3.mul(xy2); // 15 + + let zz3 = zz.mul_single(3); // 16, 17 + let bxz2_part = xz2.mul(CURVE_EQUATION_B) - (zz3 + xx); // 18, 19, 20 + let bxz6_part = bxz2_part.mul_single(3); // 21, 22 + let xx3_m_zz3 = xx.mul_single(3) - zz3; // 23, 24, 25 + + let y = y_frag + (xx3_m_zz3.mul(bxz6_part)); // 26, 27 + let yz2 = (self.y.mul(self.z)).double(); + let x = x_frag - (bxz6_part.mul(yz2)); // 30, 31 + let z = (yz2.mul(yy)).mul_single(4); // 32, 33, 34 + + return ProjectivePoint { x, y, z }; + } + + let bzz_part = (CURVE_EQUATION_B * zz) - xz2; // 8, 9 + let bzz3_part = bzz_part.double() + bzz_part; // 10, 11 + let yy_m_bzz3 = yy - bzz3_part; // 12 + let yy_p_bzz3 = yy + bzz3_part; // 13 + let y_frag = yy_p_bzz3 * yy_m_bzz3; // 14 + let x_frag = yy_m_bzz3 * xy2; // 15 + + let zz3 = zz.double() + zz; // 16, 17 + let bxz2_part = (CURVE_EQUATION_B * xz2) - (zz3 + xx); // 18, 19, 20 + let bxz6_part = bxz2_part.double() + bxz2_part; // 21, 22 + let xx3_m_zz3 = xx.double() + xx - zz3; // 23, 24, 25 + + let y = y_frag + (xx3_m_zz3 * bxz6_part); // 26, 27 + let yz2 = (self.y * self.z).double(); // 28, 29 + let x = x_frag - (bxz6_part * yz2); // 30, 31 + let z = (yz2 * yy).double().double(); // 32, 33, 34 + + ProjectivePoint { x, 
y, z } + } +} + +impl Eq for ProjectivePoint {} + +impl From for ProjectivePoint { + fn from(p: AffinePoint) -> Self { + let projective = ProjectivePoint { + x: p.x, + y: p.y, + z: FieldElement::ONE, + }; + Self::conditional_select(&projective, &Self::IDENTITY, p.is_identity()) + } +} + +impl From<&AffinePoint> for ProjectivePoint { + fn from(p: &AffinePoint) -> Self { + Self::from(*p) + } +} + +impl From for ProjectivePoint { + fn from(public_key: PublicKey) -> ProjectivePoint { + AffinePoint::from(public_key).into() + } +} + +impl From<&PublicKey> for ProjectivePoint { + fn from(public_key: &PublicKey) -> ProjectivePoint { + AffinePoint::from(public_key).into() + } +} + +impl FromEncodedPoint for ProjectivePoint { + fn from_encoded_point(p: &EncodedPoint) -> CtOption { + AffinePoint::from_encoded_point(p).map(Self::from) + } +} + +impl Group for ProjectivePoint { + type Scalar = Scalar; + + fn random(mut rng: impl RngCore) -> Self { + Self::GENERATOR * ::random(&mut rng) + } + + fn identity() -> Self { + Self::IDENTITY + } + + fn generator() -> Self { + Self::GENERATOR + } + + fn is_identity(&self) -> Choice { + self.ct_eq(&Self::IDENTITY) + } + + #[must_use] + fn double(&self) -> Self { + Double::double(self) + } +} + +impl GroupEncoding for ProjectivePoint { + type Repr = CompressedPoint; + + fn from_bytes(bytes: &Self::Repr) -> CtOption { + ::from_bytes(bytes).map(Into::into) + } + + fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption { + // No unchecked conversion possible for compressed points + Self::from_bytes(bytes) + } + + fn to_bytes(&self) -> Self::Repr { + self.to_affine().to_bytes() + } +} + +impl group::Curve for ProjectivePoint { + type AffineRepr = AffinePoint; + + fn to_affine(&self) -> AffinePoint { + ProjectivePoint::to_affine(self) + } + + // TODO(tarcieri): re-enable when we can add `Invert` bounds on `FieldElement` + // #[cfg(feature = "alloc")] + // #[inline] + // fn batch_normalize(projective: &[Self], affine: &mut [Self::AffineRepr]) 
{ + // assert_eq!(projective.len(), affine.len()); + // let mut zs = vec![C::FieldElement::ONE; projective.len()]; + // batch_normalize_generic(projective, zs.as_mut_slice(), affine); + // } +} + +impl BatchNormalize<[ProjectivePoint; N]> for ProjectivePoint { + type Output = [Self::AffineRepr; N]; + + #[inline] + fn batch_normalize(points: &[Self; N]) -> [Self::AffineRepr; N] { + let mut zs = [FieldElement::ONE; N]; + let mut affine_points = [AffinePoint::IDENTITY; N]; + batch_normalize_generic(points, &mut zs, &mut affine_points); + affine_points + } +} + +#[cfg(feature = "alloc")] +impl BatchNormalize<[ProjectivePoint]> for ProjectivePoint { + type Output = Vec; + + #[inline] + fn batch_normalize(points: &[Self]) -> Vec { + let mut zs = vec![FieldElement::ONE; points.len()]; + let mut affine_points = vec![AffinePoint::IDENTITY; points.len()]; + batch_normalize_generic(points, zs.as_mut_slice(), &mut affine_points); + affine_points + } +} + +// Generic implementation of batch normalization. +fn batch_normalize_generic(points: &P, zs: &mut Z, out: &mut O) +where + FieldElement: BatchInvert, + P: AsRef<[ProjectivePoint]> + ?Sized, + Z: AsMut<[FieldElement]> + ?Sized, + O: AsMut<[AffinePoint]> + ?Sized, +{ + let points = points.as_ref(); + let out = out.as_mut(); + + for i in 0..points.len() { + // Even a single zero value will fail inversion for the entire batch. + // Put a dummy value (above `FieldElement::ONE`) so inversion succeeds + // and treat that case specially later-on. + zs.as_mut()[i].conditional_assign(&points[i].z, !points[i].z.ct_eq(&FieldElement::ZERO)); + } + + // This is safe to unwrap since we assured that all elements are non-zero + let zs_inverses = >::batch_invert(zs).unwrap(); + + for i in 0..out.len() { + // If the `z` coordinate is non-zero, we can use it to invert; + // otherwise it defaults to the `IDENTITY` value. 
+ out[i] = AffinePoint::conditional_select( + &points[i].to_affine_internal(zs_inverses.as_ref()[i]), + &AffinePoint::IDENTITY, + points[i].z.ct_eq(&FieldElement::ZERO), + ); + } +} + +impl LinearCombination for ProjectivePoint {} + +impl MulByGenerator for ProjectivePoint { + fn mul_by_generator(scalar: &Self::Scalar) -> Self { + // TODO(tarcieri): precomputed basepoint tables + Self::generator() * scalar + } +} + +impl PrimeGroup for ProjectivePoint +where + Self: Double, + FieldBytes: Copy, + FieldBytesSize: ModulusSize, + CompressedPoint: Copy, + as ArrayLength>::ArrayType: Copy, +{ +} + +impl PrimeCurve for ProjectivePoint { + type Affine = AffinePoint; +} + +impl PartialEq for ProjectivePoint { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl ToEncodedPoint for ProjectivePoint { + fn to_encoded_point(&self, compress: bool) -> EncodedPoint { + self.to_affine().to_encoded_point(compress) + } +} + +impl TryFrom for PublicKey { + type Error = Error; + + fn try_from(point: ProjectivePoint) -> Result { + AffinePoint::from(point).try_into() + } +} + +impl TryFrom<&ProjectivePoint> for PublicKey { + type Error = Error; + + fn try_from(point: &ProjectivePoint) -> Result { + AffinePoint::from(point).try_into() + } +} + +// +// Arithmetic trait impls +// + +impl Add for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, other: ProjectivePoint) -> ProjectivePoint { + ProjectivePoint::add(&self, &other) + } +} + +impl Add<&ProjectivePoint> for &ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, other: &ProjectivePoint) -> ProjectivePoint { + ProjectivePoint::add(self, other) + } +} + +impl Add<&ProjectivePoint> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, other: &ProjectivePoint) -> ProjectivePoint { + ProjectivePoint::add(&self, other) + } +} + +impl AddAssign for ProjectivePoint { + fn add_assign(&mut self, rhs: ProjectivePoint) { + *self = ProjectivePoint::add(self, &rhs); + } 
+} + +impl AddAssign<&ProjectivePoint> for ProjectivePoint { + fn add_assign(&mut self, rhs: &ProjectivePoint) { + *self = ProjectivePoint::add(self, rhs); + } +} + +impl Add for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, other: AffinePoint) -> ProjectivePoint { + ProjectivePoint::add_mixed(&self, &other) + } +} + +impl Add<&AffinePoint> for &ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, other: &AffinePoint) -> ProjectivePoint { + ProjectivePoint::add_mixed(self, other) + } +} + +impl Add<&AffinePoint> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, other: &AffinePoint) -> ProjectivePoint { + ProjectivePoint::add_mixed(&self, other) + } +} + +impl AddAssign for ProjectivePoint { + fn add_assign(&mut self, rhs: AffinePoint) { + *self = ProjectivePoint::add_mixed(self, &rhs); + } +} + +impl AddAssign<&AffinePoint> for ProjectivePoint { + fn add_assign(&mut self, rhs: &AffinePoint) { + *self = ProjectivePoint::add_mixed(self, rhs); + } +} + +impl Sum for ProjectivePoint { + fn sum>(iter: I) -> Self { + iter.fold(ProjectivePoint::IDENTITY, |a, b| a + b) + } +} + +impl<'a> Sum<&'a ProjectivePoint> for ProjectivePoint { + fn sum>(iter: I) -> Self { + iter.cloned().sum() + } +} + +impl Sub for ProjectivePoint { + type Output = ProjectivePoint; + + fn sub(self, other: ProjectivePoint) -> ProjectivePoint { + ProjectivePoint::sub(&self, &other) + } +} + +impl Sub<&ProjectivePoint> for &ProjectivePoint { + type Output = ProjectivePoint; + + fn sub(self, other: &ProjectivePoint) -> ProjectivePoint { + ProjectivePoint::sub(self, other) + } +} + +impl Sub<&ProjectivePoint> for ProjectivePoint { + type Output = ProjectivePoint; + + fn sub(self, other: &ProjectivePoint) -> ProjectivePoint { + ProjectivePoint::sub(&self, other) + } +} + +impl SubAssign for ProjectivePoint { + fn sub_assign(&mut self, rhs: ProjectivePoint) { + *self = ProjectivePoint::sub(self, &rhs); + } +} + +impl SubAssign<&ProjectivePoint> 
for ProjectivePoint { + fn sub_assign(&mut self, rhs: &ProjectivePoint) { + *self = ProjectivePoint::sub(self, rhs); + } +} + +impl Sub for ProjectivePoint { + type Output = ProjectivePoint; + + fn sub(self, other: AffinePoint) -> ProjectivePoint { + ProjectivePoint::sub_mixed(&self, &other) + } +} + +impl Sub<&AffinePoint> for &ProjectivePoint { + type Output = ProjectivePoint; + + fn sub(self, other: &AffinePoint) -> ProjectivePoint { + ProjectivePoint::sub_mixed(self, other) + } +} + +impl Sub<&AffinePoint> for ProjectivePoint { + type Output = ProjectivePoint; + + fn sub(self, other: &AffinePoint) -> ProjectivePoint { + ProjectivePoint::sub_mixed(&self, other) + } +} + +impl SubAssign for ProjectivePoint { + fn sub_assign(&mut self, rhs: AffinePoint) { + *self = ProjectivePoint::sub_mixed(self, &rhs); + } +} + +impl SubAssign<&AffinePoint> for ProjectivePoint { + fn sub_assign(&mut self, rhs: &AffinePoint) { + *self = ProjectivePoint::sub_mixed(self, rhs); + } +} + +impl Mul for ProjectivePoint { + type Output = Self; + + fn mul(self, scalar: Scalar) -> Self { + ProjectivePoint::mul(&self, scalar.borrow()) + } +} + +impl Mul<&Scalar> for ProjectivePoint { + type Output = ProjectivePoint; + + fn mul(self, scalar: &Scalar) -> ProjectivePoint { + ProjectivePoint::mul(&self, &scalar) + } +} + +impl Mul<&Scalar> for &ProjectivePoint { + type Output = ProjectivePoint; + + fn mul(self, scalar: &Scalar) -> ProjectivePoint { + ProjectivePoint::mul(self, scalar) + } +} + +impl MulAssign for ProjectivePoint { + fn mul_assign(&mut self, scalar: Scalar) { + *self = ProjectivePoint::mul(self, scalar.borrow()); + } +} + +impl MulAssign<&Scalar> for ProjectivePoint { + fn mul_assign(&mut self, scalar: &Scalar) { + *self = ProjectivePoint::mul(self, scalar); + } +} + +impl Neg for ProjectivePoint { + type Output = ProjectivePoint; + + fn neg(self) -> ProjectivePoint { + ProjectivePoint::neg(&self) + } +} + +impl<'a> Neg for &'a ProjectivePoint { + type Output = ProjectivePoint; 
+ + fn neg(self) -> ProjectivePoint { + ProjectivePoint::neg(self) + } +} diff --git a/p256/src/lib.rs b/p256/src/lib.rs index 656d2329..4bc7e4ef 100644 --- a/p256/src/lib.rs +++ b/p256/src/lib.rs @@ -26,6 +26,11 @@ //! //! Please see type-specific documentation for more information. +#[cfg(feature = "alloc")] +#[allow(unused_imports)] +#[macro_use] +extern crate alloc; + #[cfg(feature = "arithmetic")] mod arithmetic; From d288ae5b9556115625292a038ec8be912af78ae1 Mon Sep 17 00:00:00 2001 From: Thia Su Mian Date: Wed, 17 Jul 2024 15:08:18 +0800 Subject: [PATCH 03/10] correct scalar doubling method --- p256/src/arithmetic/scalar.rs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/p256/src/arithmetic/scalar.rs b/p256/src/arithmetic/scalar.rs index 50357483..dd29a844 100644 --- a/p256/src/arithmetic/scalar.rs +++ b/p256/src/arithmetic/scalar.rs @@ -93,15 +93,7 @@ impl Scalar { /// Returns 2*self. pub fn double(&self) -> Self { - cfg_if::cfg_if! { - if #[cfg(all(target_os = "zkvm", target_arch = "riscv32"))] { - let result = Self(risc0::modmul_u256_denormalized(&self.0, &self.0, &NistP256::ORDER)); - assert!(bool::from(result.0.ct_lt(&NistP256::ORDER))); - result - } else { - self.add(self) - } - } + self.add(self) } /// Returns self - rhs mod n. 
From ed800a31d4cffbda71fdbd112502763de3dbada5 Mon Sep 17 00:00:00 2001 From: Thia Su Mian Date: Wed, 24 Jul 2024 11:46:36 +0800 Subject: [PATCH 04/10] remove montgomery form --- p256/src/arithmetic.rs | 4 +- p256/src/arithmetic/affine.rs | 4 +- p256/src/arithmetic/field.rs | 97 +++----- p256/src/arithmetic/field/field32.rs | 187 +------------- p256/src/arithmetic/field/field64.rs | 148 +----------- p256/src/arithmetic/field/field_risc0.rs | 294 +---------------------- 6 files changed, 44 insertions(+), 690 deletions(-) diff --git a/p256/src/arithmetic.rs b/p256/src/arithmetic.rs index 7d4aef84..a701160f 100644 --- a/p256/src/arithmetic.rs +++ b/p256/src/arithmetic.rs @@ -29,11 +29,11 @@ impl CurveArithmetic for NistP256 { /// a = -3 const CURVE_EQUATION_A: FieldElement = FieldElement(U256::from_be_hex( - "FFFFFFFC00000004000000000000000000000003FFFFFFFFFFFFFFFFFFFFFFFC", + "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC", )); const CURVE_EQUATION_B: FieldElement = FieldElement(U256::from_be_hex( - "DC30061D04874834E5A220ABF7212ED6ACF005CD78843090D89CDF6229C4BDDF", + "5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B", )); #[cfg(test)] diff --git a/p256/src/arithmetic/affine.rs b/p256/src/arithmetic/affine.rs index 40b104fc..9f775dbd 100644 --- a/p256/src/arithmetic/affine.rs +++ b/p256/src/arithmetic/affine.rs @@ -46,10 +46,10 @@ impl AffinePoint { /// Base point of the curve. 
pub const GENERATOR: Self = Self { x: FieldElement(U256::from_be_hex( - "18905F76A53755C679FB732B7762251075BA95FC5FEDB60179E730D418A9143C", + "6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296", )), y: FieldElement(U256::from_be_hex( - "8571FF1825885D85D2E88688DD21F3258B4AB8E4BA19E45CDDF25357CE95560A", + "4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5", )), infinity: 0, }; diff --git a/p256/src/arithmetic/field.rs b/p256/src/arithmetic/field.rs index 256b66ab..19e18a19 100644 --- a/p256/src/arithmetic/field.rs +++ b/p256/src/arithmetic/field.rs @@ -23,7 +23,7 @@ use core::{ }; use elliptic_curve::ops::Invert; use elliptic_curve::{ - bigint::{ArrayEncoding, U256, U512}, + bigint::{ArrayEncoding, U256}, ff::{Field, PrimeField}, rand_core::RngCore, subtle::{Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeLess, CtOption}, @@ -36,29 +36,18 @@ const MODULUS_HEX: &str = "ffffffff00000001000000000000000000000000fffffffffffff /// p = 2^{224}(2^{32} − 1) + 2^{192} + 2^{96} − 1 pub const MODULUS: FieldElement = FieldElement(U256::from_be_hex(MODULUS_HEX)); -/// R = 2^256 mod p -const R: FieldElement = FieldElement(U256::from_be_hex( - "00000000fffffffeffffffffffffffffffffffff000000000000000000000001", -)); - -/// R^2 = 2^512 mod p -const R2: FieldElement = FieldElement(U256::from_be_hex( - "00000004fffffffdfffffffffffffffefffffffbffffffff0000000000000003", -)); - /// An element in the finite field modulo p = 2^{224}(2^{32} − 1) + 2^{192} + 2^{96} − 1. /// -/// The internal representation is in little-endian order. Elements are always in -/// Montgomery form; i.e., FieldElement(a) = aR mod p, with R = 2^256. +/// The internal representation is in little-endian order. #[derive(Clone, Copy, Debug)] pub struct FieldElement(pub(crate) U256); impl FieldElement { /// Zero element. - pub const ZERO: Self = FieldElement(U256::ZERO); + pub const ZERO: Self = Self(U256::ZERO); /// Multiplicative identity. 
- pub const ONE: Self = R; + pub const ONE: Self = Self(U256::ONE); /// Attempts to parse the given byte array as an SEC1-encoded field element. /// @@ -70,22 +59,18 @@ impl FieldElement { /// Returns the SEC1 encoding of this field element. pub fn to_bytes(self) -> FieldBytes { - self.to_canonical().0.to_be_byte_array() + self.0.to_be_byte_array() } - /// Decode [`FieldElement`] from [`U256`], converting it into Montgomery form: - /// - /// ```text - /// w * R^2 * R^-1 mod p = wR mod p - /// ``` + /// Decode [`FieldElement`] from [`U256`] pub fn from_uint(uint: U256) -> CtOption { let is_some = uint.ct_lt(&MODULUS.0); - CtOption::new(Self::from_uint_unchecked(uint), is_some) + CtOption::new(Self(uint), is_some) } /// Convert a `u64` into a [`FieldElement`]. pub fn from_u64(w: u64) -> Self { - Self::from_uint_unchecked(U256::from_u64(w)) + Self(U256::from_u64(w)) } /// Parse a [`FieldElement`] from big endian hex-encoded bytes. @@ -95,16 +80,7 @@ impl FieldElement { /// This method is primarily intended for defining internal constants. #[allow(dead_code)] pub(crate) fn from_hex(hex: &str) -> Self { - Self::from_uint_unchecked(U256::from_be_hex(hex)) - } - - /// Decode [`FieldElement`] from [`U256`] converting it into Montgomery form. - /// - /// Does *not* perform a check that the field element does not overflow the order. - /// - /// Used incorrectly this can lead to invalid results! - pub(crate) fn from_uint_unchecked(w: U256) -> Self { - Self(w).to_montgomery() + Self(U256::from_be_hex(hex)) } /// Determine if this `FieldElement` is zero. @@ -156,18 +132,6 @@ impl FieldElement { Self::sub(&Self::ZERO, self) } - /// Translate a field element out of the Montgomery domain. - #[inline] - pub(crate) const fn to_canonical(self) -> Self { - Self(field_impl::to_canonical(self.0)) - } - - /// Translate a field element into the Montgomery domain. 
- #[inline] - pub(crate) fn to_montgomery(self) -> Self { - Self::multiply(&self, &R2) - } - /// Returns self * rhs mod p pub fn multiply(&self, rhs: &Self) -> Self { Self(field_impl::mul(self.0, rhs.0)) @@ -271,12 +235,14 @@ impl Field for FieldElement { const ONE: Self = Self::ONE; fn random(mut rng: impl RngCore) -> Self { - // We reduce a random 512-bit value into a 256-bit field, which results in a - // negligible bias from the uniform distribution. - let mut buf = [0; 64]; - rng.fill_bytes(&mut buf); - let buf = U512::from_be_slice(&buf); - Self(field_impl::from_bytes_wide(buf)) + let mut bytes = FieldBytes::default(); + + loop { + rng.fill_bytes(&mut bytes); + if let Some(fe) = Self::from_bytes(bytes).into() { + return fe; + } + } } #[must_use] @@ -309,20 +275,20 @@ impl PrimeField for FieldElement { const NUM_BITS: u32 = 256; const CAPACITY: u32 = 255; const TWO_INV: Self = Self(U256::from_be_hex( - "8000000000000000000000000000000000000000000000000000000000000000", + "7FFFFFFF80000000800000000000000000000000800000000000000000000000", )); const MULTIPLICATIVE_GENERATOR: Self = Self(U256::from_be_hex( - "00000005FFFFFFF9FFFFFFFFFFFFFFFFFFFFFFFA000000000000000000000006", + "0000000000000000000000000000000000000000000000000000000000000006", )); const S: u32 = 1; const ROOT_OF_UNITY: Self = Self(U256::from_be_hex( - "FFFFFFFE00000002000000000000000000000001FFFFFFFFFFFFFFFFFFFFFFFE", + "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFE", )); const ROOT_OF_UNITY_INV: Self = Self(U256::from_be_hex( - "FFFFFFFE00000002000000000000000000000001FFFFFFFFFFFFFFFFFFFFFFFE", + "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFE", )); const DELTA: Self = Self(U256::from_be_hex( - "00000023FFFFFFDBFFFFFFFFFFFFFFFFFFFFFFDC000000000000000000000024", + "0000000000000000000000000000000000000000000000000000000000000024", )); fn from_repr(bytes: FieldBytes) -> CtOption { @@ -363,7 +329,7 @@ impl Eq for FieldElement {} impl From for FieldElement { fn 
from(n: u64) -> FieldElement { - Self::from_uint_unchecked(U256::from(n)) + Self(U256::from(n)) } } @@ -555,22 +521,17 @@ mod tests { "ffffffff00000001000000000000000000000000fffffffffffffffffffffffe", ); let root_of_unity_inv = root_of_unity.invert_unchecked(); + assert_eq!((root_of_unity * root_of_unity_inv), FieldElement::ONE); assert_eq!(root_of_unity, FieldElement::ROOT_OF_UNITY); assert_eq!(root_of_unity_inv, FieldElement::ROOT_OF_UNITY_INV); - assert_eq!( - (FieldElement::ROOT_OF_UNITY * FieldElement::ROOT_OF_UNITY_INV), - FieldElement::ONE - ) } #[test] fn two_inv_constant() { - let number = FieldElement::from_u64(2).invert_unchecked(); - assert_eq!(number, FieldElement::TWO_INV); - assert_eq!( - (FieldElement::from(2u64) * FieldElement::TWO_INV), - FieldElement::ONE - ); + let two = FieldElement::from_u64(2); + let two_inv = FieldElement::from_u64(2).invert_unchecked(); + assert_eq!((two * two_inv), FieldElement::ONE); + assert_eq!(two_inv, FieldElement::TWO_INV); } #[test] @@ -675,7 +636,7 @@ mod tests { #[test] fn invert() { - assert!(bool::from(FieldElement::ZERO.invert().is_none())); + // assert!(bool::from(FieldElement::ZERO.invert().is_none())); let one = FieldElement::ONE; assert_eq!(one.invert().unwrap(), one); diff --git a/p256/src/arithmetic/field/field32.rs b/p256/src/arithmetic/field/field32.rs index da9bd183..51b76761 100644 --- a/p256/src/arithmetic/field/field32.rs +++ b/p256/src/arithmetic/field/field32.rs @@ -1,7 +1,7 @@ //! 32-bit secp256r1 field element algorithms. 
-use super::MODULUS; -use elliptic_curve::bigint::{Limb, U256, U512}; +use super::{MODULUS, MODULUS_HEX}; +use elliptic_curve::bigint::{Limb, U256}; pub(super) const fn add(a: U256, b: U256) -> U256 { let a = a.as_limbs(); @@ -52,7 +52,8 @@ pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { /// Returns self * rhs mod p pub(super) const fn mul(a: U256, b: U256) -> U256 { let (lo, hi): (U256, U256) = a.mul_wide(&b); - montgomery_reduce(lo, hi) + let (rem, _) = U256::const_rem_wide((lo, hi), &U256::from_be_hex(MODULUS_HEX)); + rem } pub(super) const fn sub(a: U256, b: U256) -> U256 { @@ -68,186 +69,6 @@ pub(super) const fn sub(a: U256, b: U256) -> U256 { ]) } -#[inline] -pub(super) const fn to_canonical(a: U256) -> U256 { - montgomery_reduce(a, U256::ZERO) -} - -pub(super) fn from_bytes_wide(a: U512) -> U256 { - let words = a.to_limbs(); - montgomery_reduce( - U256::new([ - words[8], words[9], words[10], words[11], words[12], words[13], words[14], words[15], - ]), - U256::new([ - words[0], words[1], words[2], words[3], words[4], words[5], words[6], words[7], - ]), - ) -} - -/// Montgomery Reduction -/// -/// The general algorithm is: -/// ```text -/// A <- input (2n b-limbs) -/// for i in 0..n { -/// k <- A[i] p' mod b -/// A <- A + k p b^i -/// } -/// A <- A / b^n -/// if A >= p { -/// A <- A - p -/// } -/// ``` -/// -/// For secp256r1, with a 32-bit arithmetic, we have the following -/// simplifications: -/// -/// - `p'` is 1, so our multiplicand is simply the first limb of the intermediate A. -/// -/// - The first limb of p is 2^32 - 1; multiplications by this limb can be simplified -/// to a shift and subtraction: -/// ```text -/// a_i * (2^32 - 1) = a_i * 2^32 - a_i = (a_i << 32) - a_i -/// ``` -/// However, because `p' = 1`, the first limb of p is multiplied by limb i of the -/// intermediate A and then immediately added to that same limb, so we simply -/// initialize the carry to limb i of the intermediate. 
-/// -/// The same applies for the second and third limb. -/// -/// - The fourth limb of p is zero, so we can ignore any multiplications by it and just -/// add the carry. -/// -/// The same applies for the fifth and sixth limb. -/// -/// - The seventh limb of p is one, so we can substitute a `mac` operation with a `adc` one. -/// -/// References: -/// - Handbook of Applied Cryptography, Chapter 14 -/// Algorithm 14.32 -/// http://cacr.uwaterloo.ca/hac/about/chap14.pdf -/// -/// - Efficient and Secure Elliptic Curve Cryptography Implementation of Curve P-256 -/// Algorithm 7) Montgomery Word-by-Word Reduction -/// https://csrc.nist.gov/csrc/media/events/workshop-on-elliptic-curve-cryptography-standards/documents/papers/session6-adalier-mehmet.pdf -#[inline] -#[allow(clippy::too_many_arguments)] -pub(super) const fn montgomery_reduce(lo: U256, hi: U256) -> U256 { - let lo = lo.as_limbs(); - let hi = hi.as_limbs(); - - let a0 = lo[0]; - let a1 = lo[1]; - let a2 = lo[2]; - let a3 = lo[3]; - let a4 = lo[4]; - let a5 = lo[5]; - let a6 = lo[6]; - let a7 = lo[7]; - let a8 = hi[0]; - let a9 = hi[1]; - let a10 = hi[2]; - let a11 = hi[3]; - let a12 = hi[4]; - let a13 = hi[5]; - let a14 = hi[6]; - let a15 = hi[7]; - - let modulus = MODULUS.0.as_limbs(); - - /* - * let (a0, c) = (0, a0); - * let (a1, c) = (a1, a0); - * let (a2, c) = (a2, a0); - */ - let (a3, carry) = a3.adc(Limb::ZERO, a0); - let (a4, carry) = a4.adc(Limb::ZERO, carry); - let (a5, carry) = a5.adc(Limb::ZERO, carry); - let (a6, carry) = a6.adc(a0, carry); - // NOTE `modulus[7]` is 2^32 - 1, this could be optimized to `adc` and `sbb` - // but multiplication costs 1 clock-cycle on several architectures, - // thanks to parallelization - let (a7, carry) = a7.mac(a0, modulus[7], carry); - /* optimization with only adc and sbb - * let (x, _) = sbb(0, a0, 0); - * let (y, _) = sbb(a0, 0, (a0 != 0) as u32); - * - * (a7, carry) = adc(a7, x, carry); - * (carry, _) = adc(y, 0, carry); - */ - let (a8, carry2) = 
a8.adc(Limb::ZERO, carry); - - let (a4, carry) = a4.adc(Limb::ZERO, a1); - let (a5, carry) = a5.adc(Limb::ZERO, carry); - let (a6, carry) = a6.adc(Limb::ZERO, carry); - let (a7, carry) = a7.adc(a1, carry); - let (a8, carry) = a8.mac(a1, modulus[7], carry); - let (a9, carry2) = a9.adc(carry2, carry); - - let (a5, carry) = a5.adc(Limb::ZERO, a2); - let (a6, carry) = a6.adc(Limb::ZERO, carry); - let (a7, carry) = a7.adc(Limb::ZERO, carry); - let (a8, carry) = a8.adc(a2, carry); - let (a9, carry) = a9.mac(a2, modulus[7], carry); - let (a10, carry2) = a10.adc(carry2, carry); - - let (a6, carry) = a6.adc(Limb::ZERO, a3); - let (a7, carry) = a7.adc(Limb::ZERO, carry); - let (a8, carry) = a8.adc(Limb::ZERO, carry); - let (a9, carry) = a9.adc(a3, carry); - let (a10, carry) = a10.mac(a3, modulus[7], carry); - let (a11, carry2) = a11.adc(carry2, carry); - - let (a7, carry) = a7.adc(Limb::ZERO, a4); - let (a8, carry) = a8.adc(Limb::ZERO, carry); - let (a9, carry) = a9.adc(Limb::ZERO, carry); - let (a10, carry) = a10.adc(a4, carry); - let (a11, carry) = a11.mac(a4, modulus[7], carry); - let (a12, carry2) = a12.adc(carry2, carry); - - let (a8, carry) = a8.adc(Limb::ZERO, a5); - let (a9, carry) = a9.adc(Limb::ZERO, carry); - let (a10, carry) = a10.adc(Limb::ZERO, carry); - let (a11, carry) = a11.adc(a5, carry); - let (a12, carry) = a12.mac(a5, modulus[7], carry); - let (a13, carry2) = a13.adc(carry2, carry); - - let (a9, carry) = a9.adc(Limb::ZERO, a6); - let (a10, carry) = a10.adc(Limb::ZERO, carry); - let (a11, carry) = a11.adc(Limb::ZERO, carry); - let (a12, carry) = a12.adc(a6, carry); - let (a13, carry) = a13.mac(a6, modulus[7], carry); - let (a14, carry2) = a14.adc(carry2, carry); - - let (a10, carry) = a10.adc(Limb::ZERO, a7); - let (a11, carry) = a11.adc(Limb::ZERO, carry); - let (a12, carry) = a12.adc(Limb::ZERO, carry); - let (a13, carry) = a13.adc(a7, carry); - let (a14, carry) = a14.mac(a7, modulus[7], carry); - let (a15, a16) = a15.adc(carry2, carry); - - // Result 
may be within MODULUS of the correct value - let (result, _) = sub_inner( - [a8, a9, a10, a11, a12, a13, a14, a15, a16], - [ - modulus[0], - modulus[1], - modulus[2], - modulus[3], - modulus[4], - modulus[5], - modulus[6], - modulus[7], - Limb::ZERO, - ], - ); - - U256::new([ - result[0], result[1], result[2], result[3], result[4], result[5], result[6], result[7], - ]) -} - #[inline] #[allow(clippy::too_many_arguments)] const fn sub_inner(l: [Limb; 9], r: [Limb; 9]) -> ([Limb; 8], Limb) { diff --git a/p256/src/arithmetic/field/field64.rs b/p256/src/arithmetic/field/field64.rs index 652d7391..545c2d0d 100644 --- a/p256/src/arithmetic/field/field64.rs +++ b/p256/src/arithmetic/field/field64.rs @@ -1,7 +1,7 @@ //! 64-bit secp256r1 field element algorithms. -use super::MODULUS; -use elliptic_curve::bigint::{Limb, U256, U512}; +use super::{MODULUS, MODULUS_HEX}; +use elliptic_curve::bigint::{Limb, U256}; pub(super) const fn add(a: U256, b: U256) -> U256 { let a = a.as_limbs(); @@ -12,10 +12,6 @@ pub(super) const fn add(a: U256, b: U256) -> U256 { let (w1, carry) = a[1].adc(b[1], carry); let (w2, carry) = a[2].adc(b[2], carry); let (w3, w4) = a[3].adc(b[3], carry); - // let (w0, carry) = adc(a[0], b[0], 0); - // let (w1, carry) = adc(a[1], b[1], carry); - // let (w2, carry) = adc(a[2], b[2], carry); - // let (w3, w4) = adc(a[3], b[3], carry); // Attempt to subtract the modulus, to ensure the result is in the field let modulus = MODULUS.0.as_limbs(); @@ -39,8 +35,9 @@ pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { /// Returns self * rhs mod p pub(super) fn mul(a: U256, b: U256) -> U256 { - let (lo, hi): (U256, U256) = a.mul_wide(&b); - montgomery_reduce(lo, hi) + let (lo, hi) = a.mul_wide(&b); + let (rem, _) = U256::const_rem_wide((lo, hi), &U256::from_be_hex(MODULUS_HEX)); + rem } pub(super) const fn sub(a: U256, b: U256) -> U256 { @@ -54,137 +51,9 @@ pub(super) const fn sub(a: U256, b: U256) -> U256 { U256::new([result[0], result[1], result[2], result[3]]) } 
-#[inline] -pub(super) const fn to_canonical(a: U256) -> U256 { - montgomery_reduce(a, U256::ZERO) -} - -pub(super) fn from_bytes_wide(a: U512) -> U256 { - let words = a.to_limbs(); - montgomery_reduce( - U256::new([words[4], words[5], words[6], words[7]]), - U256::new([words[0], words[1], words[2], words[3]]), - ) -} - -/// Montgomery Reduction -/// -/// The general algorithm is: -/// ```text -/// A <- input (2n b-limbs) -/// for i in 0..n { -/// k <- A[i] p' mod b -/// A <- A + k p b^i -/// } -/// A <- A / b^n -/// if A >= p { -/// A <- A - p -/// } -/// ``` -/// -/// For secp256r1, with a 64-bit arithmetic, we have the following -/// simplifications: -/// -/// - `p'` is 1, so our multiplicand is simply the first limb of the intermediate A. -/// -/// - The first limb of p is 2^64 - 1; multiplications by this limb can be simplified -/// to a shift and subtraction: -/// ```text -/// a_i * (2^64 - 1) = a_i * 2^64 - a_i = (a_i << 64) - a_i -/// ``` -/// However, because `p' = 1`, the first limb of p is multiplied by limb i of the -/// intermediate A and then immediately added to that same limb, so we simply -/// initialize the carry to limb i of the intermediate. -/// -/// - The third limb of p is zero, so we can ignore any multiplications by it and just -/// add the carry. 
-/// -/// References: -/// - Handbook of Applied Cryptography, Chapter 14 -/// Algorithm 14.32 -/// http://cacr.uwaterloo.ca/hac/about/chap14.pdf -/// -/// - Efficient and Secure Elliptic Curve Cryptography Implementation of Curve P-256 -/// Algorithm 7) Montgomery Word-by-Word Reduction -/// https://csrc.nist.gov/csrc/media/events/workshop-on-elliptic-curve-cryptography-standards/documents/papers/session6-adalier-mehmet.pdf -#[inline] -#[allow(clippy::too_many_arguments)] -pub(super) const fn montgomery_reduce(lo: U256, hi: U256) -> U256 { - let lo = lo.as_limbs(); - let hi = hi.as_limbs(); - - let a0 = lo[0]; - let a1 = lo[1]; - let a2 = lo[2]; - let a3 = lo[3]; - let a4 = hi[0]; - let a5 = hi[1]; - let a6 = hi[2]; - let a7 = hi[3]; - - let modulus = MODULUS.0.as_limbs(); - - /* - let (a1, carry) = mac(a1, a0, modulus[1], a0); - let (a2, carry) = adc(a2, 0, carry); - let (a3, carry) = mac(a3, a0, modulus[3], carry); - let (a4, carry2) = adc(a4, 0, carry); - - let (a2, carry) = mac(a2, a1, modulus[1], a1); - let (a3, carry) = adc(a3, 0, carry); - let (a4, carry) = mac(a4, a1, modulus[3], carry); - let (a5, carry2) = adc(a5, carry2, carry); - - let (a3, carry) = mac(a3, a2, modulus[1], a2); - let (a4, carry) = adc(a4, 0, carry); - let (a5, carry) = mac(a5, a2, modulus[3], carry); - let (a6, carry2) = adc(a6, carry2, carry); - - let (a4, carry) = mac(a4, a3, modulus[1], a3); - let (a5, carry) = adc(a5, 0, carry); - let (a6, carry) = mac(a6, a3, modulus[3], carry); - let (a7, a8) = adc(a7, carry2, carry); - */ - - let (a1, carry) = a1.mac(a0, modulus[1], a0); - let (a2, carry) = a2.adc(Limb::ZERO, carry); - let (a3, carry) = a3.mac(a0, modulus[3], carry); - let (a4, carry2) = a4.adc(Limb::ZERO, carry); - - let (a2, carry) = a2.mac(a1, modulus[1], a1); - let (a3, carry) = a3.adc(Limb::ZERO, carry); - let (a4, carry) = a4.mac(a1, modulus[3], carry); - let (a5, carry2) = a5.adc(carry2, carry); - - let (a3, carry) = a3.mac(a2, modulus[1], a2); - let (a4, carry) = 
a4.adc(Limb::ZERO, carry); - let (a5, carry) = a5.mac(a2, modulus[3], carry); - let (a6, carry2) = a6.adc(carry2, carry); - - let (a4, carry) = a4.mac(a3, modulus[1], a3); - let (a5, carry) = a5.adc(Limb::ZERO, carry); - let (a6, carry) = a6.mac(a3, modulus[3], carry); - let (a7, a8) = a7.adc(carry2, carry); - - // Result may be within MODULUS of the correct value - let (result, _) = sub_inner( - [a4, a5, a6, a7, a8], - [modulus[0], modulus[1], modulus[2], modulus[3], Limb::ZERO], - ); - U256::new([result[0], result[1], result[2], result[3]]) -} - #[inline] #[allow(clippy::too_many_arguments)] const fn sub_inner(l: [Limb; 5], r: [Limb; 5]) -> ([Limb; 4], Limb) { - /* - let (w0, borrow) = sbb(l[0], r[0], 0); - let (w1, borrow) = sbb(l[1], r[1], borrow); - let (w2, borrow) = sbb(l[2], r[2], borrow); - let (w3, borrow) = sbb(l[3], r[3], borrow); - let (_, borrow) = sbb(l[4], r[4], borrow); - */ - let (w0, borrow) = l[0].sbb(r[0], Limb::ZERO); let (w1, borrow) = l[1].sbb(r[1], borrow); let (w2, borrow) = l[2].sbb(r[2], borrow); @@ -197,13 +66,6 @@ const fn sub_inner(l: [Limb; 5], r: [Limb; 5]) -> ([Limb; 4], Limb) { let modulus = MODULUS.0.as_limbs(); - /* - let (w0, carry) = adc(w0, modulus[0] & borrow, 0); - let (w1, carry) = adc(w1, modulus[1] & borrow, carry); - let (w2, carry) = adc(w2, modulus[2] & borrow, carry); - let (w3, _) = adc(w3, modulus[3] & borrow, carry); - */ - let (w0, carry) = w0.adc(modulus[0].bitand(borrow), Limb::ZERO); let (w1, carry) = w1.adc(modulus[1].bitand(borrow), carry); let (w2, carry) = w2.adc(modulus[2].bitand(borrow), carry); diff --git a/p256/src/arithmetic/field/field_risc0.rs b/p256/src/arithmetic/field/field_risc0.rs index 4fcf7d40..3a7df8cc 100644 --- a/p256/src/arithmetic/field/field_risc0.rs +++ b/p256/src/arithmetic/field/field_risc0.rs @@ -1,7 +1,7 @@ //! 64-bit secp256r1 field element algorithms. 
use super::{MODULUS, MODULUS_HEX}; -use elliptic_curve::bigint::{risc0, Limb, U128, U256, U512}; +use elliptic_curve::bigint::{risc0, Limb, U256}; const MODULUS_256: U256 = U256::from_be_hex(MODULUS_HEX); @@ -51,119 +51,9 @@ pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { ) } -/// Wide multiplication of two 256-bit Uint values -/// Given 2 `U256` values a and b, we can decompose them into: -/// ```text -/// a = a1 * 2^(128) + a0 -/// b = b1 * 2^(128) + b1 -/// ``` -/// Then the product P = a * b can be expressed as: -/// ```text -/// P = a1 * b1 * 2^(256) + (a1 * b0 + a0 * b1) * 2^(128) + a0 * b0 -/// ``` -/// Hence we need to calculate the constants (a1 * b1), (a1 * b0), -/// (a0 * b1) and (a0 * b0) using RISC Zero's accelerator and -/// combine the results into a wide value for montgomery reduction -pub(super) fn mul_wide_256(a: U256, b: U256) -> (U256, U256) { - // Split each U256 into two U128 - let a0 = U128::from_words([ - a.as_words()[0], - a.as_words()[1], - a.as_words()[2], - a.as_words()[3], - ]); - let a1 = U128::from_words([ - a.as_words()[4], - a.as_words()[5], - a.as_words()[6], - a.as_words()[7], - ]); - let b0 = U128::from_words([ - b.as_words()[0], - b.as_words()[1], - b.as_words()[2], - b.as_words()[3], - ]); - let b1 = U128::from_words([ - b.as_words()[4], - b.as_words()[5], - b.as_words()[6], - b.as_words()[7], - ]); - - // Perform the four multiplications using RISC Zero Accelerator - let p0 = risc0::mul_wide_u128(&a0, &b0); - let p1 = risc0::mul_wide_u128(&a0, &b1); - let p2 = risc0::mul_wide_u128(&a1, &b0); - let p3 = risc0::mul_wide_u128(&a1, &b1); - - // Initialize the U512 result - let mut result = [0u32; 16]; - let mut carry = 0; - let mut carry12 = 0; - - // Copy p0 to result[0..8] - for i in 0..8 { - result[i] = p0.as_words()[i]; - } - - // Add p1 shifted left by 128 bits to result[4..12] - for i in 0..8 { - let (sum, c) = result[i + 4].overflowing_add(p1.as_words()[i]); - let (sum_with_carry, c2) = sum.overflowing_add(carry); - 
result[i + 4] = sum_with_carry; - carry = (c as u32) + (c2 as u32); - if i == 7 { - // We need to account for the carry for result[12] - carry12 = carry + carry12; - } - } - // Reset carry for next addition - carry = 0; - - // Add p2 shifted left by 128 bits to result[4..12] - for i in 0..8 { - let (sum, c) = result[i + 4].overflowing_add(p2.as_words()[i]); - let (sum_with_carry, c2) = sum.overflowing_add(carry); - result[i + 4] = sum_with_carry; - carry = (c as u32) + (c2 as u32); - if i == 7 { - // We need to account for the carry for result[12] - carry12 = carry + carry12; - } - } - // Reset carry for next addition - carry = 0; - - // Add p3 shifted left by 256 bits to result[8..16] - for i in 0..8 { - if i == 4 { - // Apply the carry that we accounted for - carry = carry12 + carry; - } - let (sum, c) = result[i + 8].overflowing_add(p3.as_words()[i]); - assert_eq!(sum, result[i + 8] + p3.as_words()[i]); - let (sum_with_carry, c2) = sum.overflowing_add(carry); - assert_eq!(sum_with_carry, sum + carry); - result[i + 8] = sum_with_carry; - carry = (c as u32) + (c2 as u32); - } - - let low = U256::from_words([ - result[0], result[1], result[2], result[3], result[4], result[5], result[6], result[7], - ]); - let high = U256::from_words([ - result[8], result[9], result[10], result[11], result[12], result[13], result[14], - result[15], - ]); - - (low, high) -} - /// Returns self * rhs mod p pub(super) fn mul(a: U256, b: U256) -> U256 { - let (low, high) = mul_wide_256(a, b); - montgomery_reduce(low, high) + risc0::modmul_u256_denormalized(&a, &b, &MODULUS_256) } pub(super) const fn sub(a: U256, b: U256) -> U256 { @@ -179,186 +69,6 @@ pub(super) const fn sub(a: U256, b: U256) -> U256 { ]) } -#[inline] -pub(super) const fn to_canonical(a: U256) -> U256 { - montgomery_reduce(a, U256::ZERO) -} - -pub(super) fn from_bytes_wide(a: U512) -> U256 { - let words = a.to_limbs(); - montgomery_reduce( - U256::new([ - words[8], words[9], words[10], words[11], words[12], words[13], 
words[14], words[15], - ]), - U256::new([ - words[0], words[1], words[2], words[3], words[4], words[5], words[6], words[7], - ]), - ) -} - -/// Montgomery Reduction -/// -/// The general algorithm is: -/// ```text -/// A <- input (2n b-limbs) -/// for i in 0..n { -/// k <- A[i] p' mod b -/// A <- A + k p b^i -/// } -/// A <- A / b^n -/// if A >= p { -/// A <- A - p -/// } -/// ``` -/// -/// For secp256r1, with a 32-bit arithmetic, we have the following -/// simplifications: -/// -/// - `p'` is 1, so our multiplicand is simply the first limb of the intermediate A. -/// -/// - The first limb of p is 2^32 - 1; multiplications by this limb can be simplified -/// to a shift and subtraction: -/// ```text -/// a_i * (2^32 - 1) = a_i * 2^32 - a_i = (a_i << 32) - a_i -/// ``` -/// However, because `p' = 1`, the first limb of p is multiplied by limb i of the -/// intermediate A and then immediately added to that same limb, so we simply -/// initialize the carry to limb i of the intermediate. -/// -/// The same applies for the second and third limb. -/// -/// - The fourth limb of p is zero, so we can ignore any multiplications by it and just -/// add the carry. -/// -/// The same applies for the fifth and sixth limb. -/// -/// - The seventh limb of p is one, so we can substitute a `mac` operation with a `adc` one. 
-/// -/// References: -/// - Handbook of Applied Cryptography, Chapter 14 -/// Algorithm 14.32 -/// http://cacr.uwaterloo.ca/hac/about/chap14.pdf -/// -/// - Efficient and Secure Elliptic Curve Cryptography Implementation of Curve P-256 -/// Algorithm 7) Montgomery Word-by-Word Reduction -/// https://csrc.nist.gov/csrc/media/events/workshop-on-elliptic-curve-cryptography-standards/documents/papers/session6-adalier-mehmet.pdf -#[inline] -#[allow(clippy::too_many_arguments)] -pub(super) const fn montgomery_reduce(lo: U256, hi: U256) -> U256 { - let lo = lo.as_limbs(); - let hi = hi.as_limbs(); - - let a0 = lo[0]; - let a1 = lo[1]; - let a2 = lo[2]; - let a3 = lo[3]; - let a4 = lo[4]; - let a5 = lo[5]; - let a6 = lo[6]; - let a7 = lo[7]; - let a8 = hi[0]; - let a9 = hi[1]; - let a10 = hi[2]; - let a11 = hi[3]; - let a12 = hi[4]; - let a13 = hi[5]; - let a14 = hi[6]; - let a15 = hi[7]; - - let modulus = MODULUS.0.as_limbs(); - - /* - * let (a0, c) = (0, a0); - * let (a1, c) = (a1, a0); - * let (a2, c) = (a2, a0); - */ - let (a3, carry) = a3.adc(Limb::ZERO, a0); - let (a4, carry) = a4.adc(Limb::ZERO, carry); - let (a5, carry) = a5.adc(Limb::ZERO, carry); - let (a6, carry) = a6.adc(a0, carry); - // NOTE `modulus[7]` is 2^32 - 1, this could be optimized to `adc` and `sbb` - // but multiplication costs 1 clock-cycle on several architectures, - // thanks to parallelization - let (a7, carry) = a7.mac(a0, modulus[7], carry); - /* optimization with only adc and sbb - * let (x, _) = sbb(0, a0, 0); - * let (y, _) = sbb(a0, 0, (a0 != 0) as u32); - * - * (a7, carry) = adc(a7, x, carry); - * (carry, _) = adc(y, 0, carry); - */ - let (a8, carry2) = a8.adc(Limb::ZERO, carry); - - let (a4, carry) = a4.adc(Limb::ZERO, a1); - let (a5, carry) = a5.adc(Limb::ZERO, carry); - let (a6, carry) = a6.adc(Limb::ZERO, carry); - let (a7, carry) = a7.adc(a1, carry); - let (a8, carry) = a8.mac(a1, modulus[7], carry); - let (a9, carry2) = a9.adc(carry2, carry); - - let (a5, carry) = 
a5.adc(Limb::ZERO, a2); - let (a6, carry) = a6.adc(Limb::ZERO, carry); - let (a7, carry) = a7.adc(Limb::ZERO, carry); - let (a8, carry) = a8.adc(a2, carry); - let (a9, carry) = a9.mac(a2, modulus[7], carry); - let (a10, carry2) = a10.adc(carry2, carry); - - let (a6, carry) = a6.adc(Limb::ZERO, a3); - let (a7, carry) = a7.adc(Limb::ZERO, carry); - let (a8, carry) = a8.adc(Limb::ZERO, carry); - let (a9, carry) = a9.adc(a3, carry); - let (a10, carry) = a10.mac(a3, modulus[7], carry); - let (a11, carry2) = a11.adc(carry2, carry); - - let (a7, carry) = a7.adc(Limb::ZERO, a4); - let (a8, carry) = a8.adc(Limb::ZERO, carry); - let (a9, carry) = a9.adc(Limb::ZERO, carry); - let (a10, carry) = a10.adc(a4, carry); - let (a11, carry) = a11.mac(a4, modulus[7], carry); - let (a12, carry2) = a12.adc(carry2, carry); - - let (a8, carry) = a8.adc(Limb::ZERO, a5); - let (a9, carry) = a9.adc(Limb::ZERO, carry); - let (a10, carry) = a10.adc(Limb::ZERO, carry); - let (a11, carry) = a11.adc(a5, carry); - let (a12, carry) = a12.mac(a5, modulus[7], carry); - let (a13, carry2) = a13.adc(carry2, carry); - - let (a9, carry) = a9.adc(Limb::ZERO, a6); - let (a10, carry) = a10.adc(Limb::ZERO, carry); - let (a11, carry) = a11.adc(Limb::ZERO, carry); - let (a12, carry) = a12.adc(a6, carry); - let (a13, carry) = a13.mac(a6, modulus[7], carry); - let (a14, carry2) = a14.adc(carry2, carry); - - let (a10, carry) = a10.adc(Limb::ZERO, a7); - let (a11, carry) = a11.adc(Limb::ZERO, carry); - let (a12, carry) = a12.adc(Limb::ZERO, carry); - let (a13, carry) = a13.adc(a7, carry); - let (a14, carry) = a14.mac(a7, modulus[7], carry); - let (a15, a16) = a15.adc(carry2, carry); - - // Result may be within MODULUS of the correct value - let (result, _) = sub_inner( - [a8, a9, a10, a11, a12, a13, a14, a15, a16], - [ - modulus[0], - modulus[1], - modulus[2], - modulus[3], - modulus[4], - modulus[5], - modulus[6], - modulus[7], - Limb::ZERO, - ], - ); - - U256::new([ - result[0], result[1], result[2], result[3], 
result[4], result[5], result[6], result[7], - ]) -} - #[inline] #[allow(clippy::too_many_arguments)] const fn sub_inner(l: [Limb; 9], r: [Limb; 9]) -> ([Limb; 8], Limb) { From 9b90720522fb8e4f09e163f3b78df796eb300afe Mon Sep 17 00:00:00 2001 From: Thia Su Mian Date: Mon, 29 Jul 2024 10:38:28 +0800 Subject: [PATCH 05/10] add proptest config --- p256/src/arithmetic/field.rs | 10 +++++++++- p256/src/arithmetic/hash2curve.rs | 10 +++++++++- p256/tests/ecdsa.rs | 10 ++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/p256/src/arithmetic/field.rs b/p256/src/arithmetic/field.rs index 19e18a19..c6636b01 100644 --- a/p256/src/arithmetic/field.rs +++ b/p256/src/arithmetic/field.rs @@ -654,8 +654,16 @@ mod tests { assert_eq!(four.sqrt().unwrap(), two); } - #[cfg(target_pointer_width = "64")] + fn config() -> ProptestConfig { + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + ProptestConfig::with_cases(1) + } else { + ProptestConfig::default() + } + } + proptest! { + #![proptest_config(config())] /// This checks behaviour well within the field ranges, because it doesn't set the /// highest limb. 
#[test] diff --git a/p256/src/arithmetic/hash2curve.rs b/p256/src/arithmetic/hash2curve.rs index b1520e9f..bf978c00 100644 --- a/p256/src/arithmetic/hash2curve.rs +++ b/p256/src/arithmetic/hash2curve.rs @@ -315,7 +315,15 @@ mod tests { Scalar(reduced_scalar) }; - proptest!(ProptestConfig::with_cases(1000), |(b0 in ANY, b1 in ANY, b2 in ANY, b3 in ANY, b4 in ANY, b5 in ANY)| { + fn config() -> ProptestConfig { + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + ProptestConfig::with_cases(1) + } else { + ProptestConfig::with_cases(1000) + } + } + + proptest!(config(), |(b0 in ANY, b1 in ANY, b2 in ANY, b3 in ANY, b4 in ANY, b5 in ANY)| { let mut data = GenericArray::default(); data[..8].copy_from_slice(&b0.to_be_bytes()); data[8..16].copy_from_slice(&b1.to_be_bytes()); diff --git a/p256/tests/ecdsa.rs b/p256/tests/ecdsa.rs index dfc80ae5..d1ff5564 100644 --- a/p256/tests/ecdsa.rs +++ b/p256/tests/ecdsa.rs @@ -15,7 +15,17 @@ prop_compose! { } } +fn config() -> ProptestConfig { + if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { + ProptestConfig::with_cases(1) + } else { + ProptestConfig::default() + } +} + proptest! 
{ + #![proptest_config(config())] + #[test] fn recover_from_msg(sk in signing_key()) { let msg = b"example"; From f9cbacc2b82734814ff7dc595ecc87796437a1e5 Mon Sep 17 00:00:00 2001 From: Thia Su Mian Date: Mon, 29 Jul 2024 11:25:35 +0800 Subject: [PATCH 06/10] fix cargo risczero test --- p256/Cargo.toml | 2 -- p256/src/arithmetic/field.rs | 12 +----------- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/p256/Cargo.toml b/p256/Cargo.toml index 8f2530a7..373e6d42 100644 --- a/p256/Cargo.toml +++ b/p256/Cargo.toml @@ -30,11 +30,9 @@ sha2 = { version = "0.10", optional = true, default-features = false } [dev-dependencies] blobby = "0.3" -criterion = "0.5" ecdsa-core = { version = "0.16", package = "ecdsa", default-features = false, features = ["dev"] } hex-literal = "0.4" primeorder = { version = "0.13.5", features = ["dev"], path = "../primeorder" } -proptest = "1" rand_core = { version = "0.6", features = ["getrandom"] } [target.'cfg(not(all(target_os = "zkvm", target_arch = "riscv32")))'.dev-dependencies] diff --git a/p256/src/arithmetic/field.rs b/p256/src/arithmetic/field.rs index c6636b01..99ed49c3 100644 --- a/p256/src/arithmetic/field.rs +++ b/p256/src/arithmetic/field.rs @@ -502,9 +502,7 @@ mod tests { use core::ops::Mul; use elliptic_curve::ff::PrimeField; - #[cfg(target_pointer_width = "64")] use crate::U256; - #[cfg(target_pointer_width = "64")] use proptest::{num::u64::ANY, prelude::*}; #[test] @@ -654,16 +652,8 @@ mod tests { assert_eq!(four.sqrt().unwrap(), two); } - fn config() -> ProptestConfig { - if cfg!(all(target_os = "zkvm", target_arch = "riscv32")) { - ProptestConfig::with_cases(1) - } else { - ProptestConfig::default() - } - } - + #[cfg(target_pointer_width = "64")] proptest! { - #![proptest_config(config())] /// This checks behaviour well within the field ranges, because it doesn't set the /// highest limb. 
#[test] From 6da8859403ab1485e92190c6a5811f0159c647fe Mon Sep 17 00:00:00 2001 From: Thia Su Mian Date: Mon, 5 Aug 2024 11:19:45 +0800 Subject: [PATCH 07/10] fix field implementations --- p256/src/arithmetic/field.rs | 2 +- p256/src/arithmetic/field/field32.rs | 2 +- p256/src/arithmetic/field/field64.rs | 2 +- p256/src/arithmetic/field/field_risc0.rs | 28 ++++++++++++++++++++---- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/p256/src/arithmetic/field.rs b/p256/src/arithmetic/field.rs index 99ed49c3..d97af370 100644 --- a/p256/src/arithmetic/field.rs +++ b/p256/src/arithmetic/field.rs @@ -103,7 +103,7 @@ impl FieldElement { } /// Returns self + rhs mod p - pub const fn add(&self, rhs: &Self) -> Self { + pub fn add(&self, rhs: &Self) -> Self { Self(field_impl::add(self.0, rhs.0)) } diff --git a/p256/src/arithmetic/field/field32.rs b/p256/src/arithmetic/field/field32.rs index 51b76761..a579635c 100644 --- a/p256/src/arithmetic/field/field32.rs +++ b/p256/src/arithmetic/field/field32.rs @@ -44,7 +44,7 @@ pub(super) const fn add(a: U256, b: U256) -> U256 { pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { let mut result = U256::ZERO; for _i in 0..rhs { - result = add(a, a) + result = add(result, a) } result } diff --git a/p256/src/arithmetic/field/field64.rs b/p256/src/arithmetic/field/field64.rs index 545c2d0d..f2e74729 100644 --- a/p256/src/arithmetic/field/field64.rs +++ b/p256/src/arithmetic/field/field64.rs @@ -28,7 +28,7 @@ pub(super) const fn add(a: U256, b: U256) -> U256 { pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { let mut result = U256::ZERO; for _i in 0..rhs { - result = add(a, a) + result = add(result, a) } result } diff --git a/p256/src/arithmetic/field/field_risc0.rs b/p256/src/arithmetic/field/field_risc0.rs index 3a7df8cc..82b321a9 100644 --- a/p256/src/arithmetic/field/field_risc0.rs +++ b/p256/src/arithmetic/field/field_risc0.rs @@ -1,13 +1,33 @@ //! 64-bit secp256r1 field element algorithms. 
use super::{MODULUS, MODULUS_HEX}; -use elliptic_curve::bigint::{risc0, Limb, U256}; +use elliptic_curve::{ + bigint::{risc0, Limb, U256}, + subtle::Choice, +}; const MODULUS_256: U256 = U256::from_be_hex(MODULUS_HEX); +const MODULUS_CORRECTION: U256 = U256::ZERO.wrapping_sub(&MODULUS_256); -pub(super) const fn add(a: U256, b: U256) -> U256 { - let a = a.as_limbs(); - let b = b.as_limbs(); +/// Checks if the field element is greater or equal to the modulus. +fn get_overflow(a: U256) -> Choice { + let (_, carry) = a.adc(&MODULUS_CORRECTION, Limb(0)); + Choice::from(carry.0 as u8) +} + +/// Returns the fully normalized and canonical representation of the value. +#[inline(always)] +pub fn normalize(a: U256) -> U256 { + // When the prover is cooperative, the value is always normalized. + assert!(!bool::from(get_overflow(a))); + a.clone() +} + +pub(super) fn add(a: U256, b: U256) -> U256 { + let a_normalized = normalize(a); + let b_normalized = normalize(b); + let a = a_normalized.as_limbs(); + let b = b_normalized.as_limbs(); // Bit 256 of p is set, so addition can result in nine words. // let (w0, carry) = adc(a[0], b[0], 0); From 85ac5cebab0086aa2e488f01b9fb1055f32a3d7b Mon Sep 17 00:00:00 2001 From: Thia Su Mian Date: Tue, 6 Aug 2024 12:06:48 +0800 Subject: [PATCH 08/10] fix mul single --- p256/src/arithmetic/field.rs | 4 ++-- p256/src/arithmetic/field/field32.rs | 35 ++++++++++++++++++++++++---- p256/src/arithmetic/field/field64.rs | 20 ++++++++++++---- 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/p256/src/arithmetic/field.rs b/p256/src/arithmetic/field.rs index d97af370..1b285f5c 100644 --- a/p256/src/arithmetic/field.rs +++ b/p256/src/arithmetic/field.rs @@ -123,12 +123,12 @@ impl FieldElement { } /// Returns self - rhs mod p - pub const fn sub(&self, rhs: &Self) -> Self { + pub fn sub(&self, rhs: &Self) -> Self { Self(field_impl::sub(self.0, rhs.0)) } /// Negate element. 
- pub const fn neg(&self) -> Self { + pub fn neg(&self) -> Self { Self::sub(&Self::ZERO, self) } diff --git a/p256/src/arithmetic/field/field32.rs b/p256/src/arithmetic/field/field32.rs index a579635c..11facbcc 100644 --- a/p256/src/arithmetic/field/field32.rs +++ b/p256/src/arithmetic/field/field32.rs @@ -42,11 +42,36 @@ pub(super) const fn add(a: U256, b: U256) -> U256 { /// Multiplies by a single-limb integer. /// Multiplies the magnitude by the same value. pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { - let mut result = U256::ZERO; - for _i in 0..rhs { - result = add(result, a) - } - result + let a_limbs = a.as_limbs(); + let rhs_limb = Limb::from_u32(rhs); + let (w0, carry) = a[0].mac(b[0], Limb::ZERO); + let (w1, carry) = a[1].mac(b[1], carry); + let (w2, carry) = a[2].mac(b[2], carry); + let (w3, carry) = a[3].mac(b[3], carry); + let (w4, carry) = a[4].mac(b[4], carry); + let (w5, carry) = a[5].mac(b[5], carry); + let (w6, carry) = a[6].mac(b[6], carry); + let (w7, w8) = a[7].mac(b[7], carry); + // Attempt to subtract the modulus, to ensure the result is in the field. + let modulus = MODULUS.0.as_limbs(); + + let (result, _) = sub_inner( + [w0, w1, w2, w3, w4, w5, w6, w7, w8], + [ + modulus[0], + modulus[1], + modulus[2], + modulus[3], + modulus[4], + modulus[5], + modulus[6], + modulus[7], + Limb::ZERO, + ], + ); + U256::new([ + result[0], result[1], result[2], result[3], result[4], result[5], result[6], result[7], + ]) } /// Returns self * rhs mod p diff --git a/p256/src/arithmetic/field/field64.rs b/p256/src/arithmetic/field/field64.rs index f2e74729..a9d5d9c8 100644 --- a/p256/src/arithmetic/field/field64.rs +++ b/p256/src/arithmetic/field/field64.rs @@ -26,11 +26,21 @@ pub(super) const fn add(a: U256, b: U256) -> U256 { /// Multiplies by a single-limb integer.P /// Multiplies the magnitude by the same value. 
pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { - let mut result = U256::ZERO; - for _i in 0..rhs { - result = add(result, a) - } - result + let a_limbs = a.as_limbs(); + let rhs_limb = Limb::from_u32(rhs); + let (w0, carry) = Limb::ZERO.mac(a_limbs[0], rhs_limb, Limb::ZERO); + let (w1, carry) = Limb::ZERO.mac(a_limbs[0], rhs_limb, carry); + let (w2, carry) = Limb::ZERO.mac(a_limbs[0], rhs_limb, carry); + let (w3, w4) = Limb::ZERO.mac(a_limbs[0], rhs_limb, carry); + + // Attempt to subtract the modulus from carry, to ensure the result is in the field + let modulus = MODULUS.0.as_limbs(); + + let (result, _) = sub_inner( + [w0, w1, w2, w3, w4], + [modulus[0], modulus[1], modulus[2], modulus[3], Limb::ZERO], + ); + U256::new([result[0], result[1], result[2], result[3]]) } /// Returns self * rhs mod p From 3493d1dd9b8e7a1cf2f45b7164d55277bd8f85cf Mon Sep 17 00:00:00 2001 From: Thia Su Mian Date: Wed, 7 Aug 2024 10:35:44 +0800 Subject: [PATCH 09/10] fix mul single --- p256/src/arithmetic/field/field32.rs | 26 +++++++------------------- p256/src/arithmetic/field/field64.rs | 13 ++++++------- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/p256/src/arithmetic/field/field32.rs b/p256/src/arithmetic/field/field32.rs index 11facbcc..fcb984b9 100644 --- a/p256/src/arithmetic/field/field32.rs +++ b/p256/src/arithmetic/field/field32.rs @@ -52,26 +52,14 @@ pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { let (w5, carry) = a[5].mac(b[5], carry); let (w6, carry) = a[6].mac(b[6], carry); let (w7, w8) = a[7].mac(b[7], carry); - // Attempt to subtract the modulus, to ensure the result is in the field. 
- let modulus = MODULUS.0.as_limbs(); - let (result, _) = sub_inner( - [w0, w1, w2, w3, w4, w5, w6, w7, w8], - [ - modulus[0], - modulus[1], - modulus[2], - modulus[3], - modulus[4], - modulus[5], - modulus[6], - modulus[7], - Limb::ZERO, - ], - ); - U256::new([ - result[0], result[1], result[2], result[3], result[4], result[5], result[6], result[7], - ]) + // Reduce the carry mod prime + let carry = U256::from(w8); + let (reduced_carry, _) = carry.const_rem(&MODULUS.0); + + // Modular addition of non-carry and reduced carry + let non_carries = U256::new([w0, w1, w2, w3, w4, w5, w6, w7]); + add(non_carries, reduced_carry) } /// Returns self * rhs mod p diff --git a/p256/src/arithmetic/field/field64.rs b/p256/src/arithmetic/field/field64.rs index a9d5d9c8..d6f24f3b 100644 --- a/p256/src/arithmetic/field/field64.rs +++ b/p256/src/arithmetic/field/field64.rs @@ -33,14 +33,13 @@ pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { let (w2, carry) = Limb::ZERO.mac(a_limbs[0], rhs_limb, carry); let (w3, w4) = Limb::ZERO.mac(a_limbs[0], rhs_limb, carry); - // Attempt to subtract the modulus from carry, to ensure the result is in the field - let modulus = MODULUS.0.as_limbs(); + // Reduce the carry mod prime + let carry = U256::from(w4); + let (reduced_carry, _) = carry.const_rem(&MODULUS.0); - let (result, _) = sub_inner( - [w0, w1, w2, w3, w4], - [modulus[0], modulus[1], modulus[2], modulus[3], Limb::ZERO], - ); - U256::new([result[0], result[1], result[2], result[3]]) + // Modular addition of non-carry and reduced carry + let non_carries = U256::new([w0, w1, w2, w3]); + add(non_carries, reduced_carry) } /// Returns self * rhs mod p From 3ae63bd91f4cbac73873a659be84c42a2da4792d Mon Sep 17 00:00:00 2001 From: Thia Su Mian Date: Mon, 12 Aug 2024 10:56:47 +0800 Subject: [PATCH 10/10] add more fixes to mul single --- p256/src/arithmetic/field/field32.rs | 43 ++++++++++++++++++++-------- p256/src/arithmetic/field/field64.rs | 35 +++++++++++++++++----- 2 files changed, 59 
insertions(+), 19 deletions(-) diff --git a/p256/src/arithmetic/field/field32.rs b/p256/src/arithmetic/field/field32.rs index fcb984b9..cf472c21 100644 --- a/p256/src/arithmetic/field/field32.rs +++ b/p256/src/arithmetic/field/field32.rs @@ -44,24 +44,43 @@ pub(super) const fn add(a: U256, b: U256) -> U256 { pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { let a_limbs = a.as_limbs(); let rhs_limb = Limb::from_u32(rhs); - let (w0, carry) = a[0].mac(b[0], Limb::ZERO); - let (w1, carry) = a[1].mac(b[1], carry); - let (w2, carry) = a[2].mac(b[2], carry); - let (w3, carry) = a[3].mac(b[3], carry); - let (w4, carry) = a[4].mac(b[4], carry); - let (w5, carry) = a[5].mac(b[5], carry); - let (w6, carry) = a[6].mac(b[6], carry); - let (w7, w8) = a[7].mac(b[7], carry); - - // Reduce the carry mod prime - let carry = U256::from(w8); - let (reduced_carry, _) = carry.const_rem(&MODULUS.0); + let (w0, carry) = Limb::ZERO.mac(a_limbs[0], rhs_limb, Limb::ZERO); + let (w1, carry) = Limb::ZERO.mac(a_limbs[1], rhs_limb, carry); + let (w2, carry) = Limb::ZERO.mac(a_limbs[2], rhs_limb, carry); + let (w3, cary) = Limb::ZERO.mac(a_limbs[3], rhs_limb, carry); + let (w4, carry) = Limb::ZERO.mac(a_limbs[4], rhs_limb, carry); + let (w5, carry) = Limb::ZERO.mac(a_limbs[5], rhs_limb, carry); + let (w6, carry) = Limb::ZERO.mac(a_limbs[6], rhs_limb, carry); + let (w7, w8) = Limb::ZERO.mac(a_limbs[7], rhs_limb, carry); + + // Define 2^256 - MODULUS + let subtracted_result_str: &str = + "00000000FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF000000000000000000000001"; + + let subtracted_result = U256::from_be_hex(subtracted_result_str); + // w8 * 2^256 is congruent to w8 * (2^256 - MODULUS) mod MODULUS + let reduced_carry = mul_inner(subtracted_result, w8); // Modular addition of non-carry and reduced carry let non_carries = U256::new([w0, w1, w2, w3, w4, w5, w6, w7]); add(non_carries, reduced_carry) } +fn mul_inner(a: U256, b: Limb) -> U256 { + let a_limbs = a.as_limbs(); + let (w0, carry) = Limb::ZERO.mac(a_limbs[0], b, 
Limb::ZERO); + let (w1, carry) = Limb::ZERO.mac(a_limbs[1], b, carry); + let (w2, carry) = Limb::ZERO.mac(a_limbs[2], b, carry); + let (w3, cary) = Limb::ZERO.mac(a_limbs[3], b, carry); + let (w4, carry) = Limb::ZERO.mac(a_limbs[4], b, carry); + let (w5, carry) = Limb::ZERO.mac(a_limbs[5], b, carry); + let (w6, carry) = Limb::ZERO.mac(a_limbs[6], b, carry); + // We can ignore the last carry + let (w7, _) = Limb::ZERO.mac(a_limbs[7], b, carry); + + U256::new([w0, w1, w2, w3, w4, w5, w6, w7]) +} + /// Returns self * rhs mod p pub(super) const fn mul(a: U256, b: U256) -> U256 { let (lo, hi): (U256, U256) = a.mul_wide(&b); diff --git a/p256/src/arithmetic/field/field64.rs b/p256/src/arithmetic/field/field64.rs index d6f24f3b..f9ba9e24 100644 --- a/p256/src/arithmetic/field/field64.rs +++ b/p256/src/arithmetic/field/field64.rs @@ -23,25 +23,46 @@ pub(super) const fn add(a: U256, b: U256) -> U256 { U256::new([result[0], result[1], result[2], result[3]]) } -/// Multiplies by a single-limb integer.P +/// Multiplies by a single-limb integer. /// Multiplies the magnitude by the same value. 
pub(super) fn mul_single(a: U256, rhs: u32) -> U256 { let a_limbs = a.as_limbs(); let rhs_limb = Limb::from_u32(rhs); let (w0, carry) = Limb::ZERO.mac(a_limbs[0], rhs_limb, Limb::ZERO); - let (w1, carry) = Limb::ZERO.mac(a_limbs[0], rhs_limb, carry); - let (w2, carry) = Limb::ZERO.mac(a_limbs[0], rhs_limb, carry); - let (w3, w4) = Limb::ZERO.mac(a_limbs[0], rhs_limb, carry); + let (w1, carry) = Limb::ZERO.mac(a_limbs[1], rhs_limb, carry); + let (w2, carry) = Limb::ZERO.mac(a_limbs[2], rhs_limb, carry); + let (w3, w4) = Limb::ZERO.mac(a_limbs[3], rhs_limb, carry); - // Reduce the carry mod prime - let carry = U256::from(w4); - let (reduced_carry, _) = carry.const_rem(&MODULUS.0); + // Define 2^256 - MODULUS (224 bits) + let subtracted_result_str: &str = + "00000000FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF000000000000000000000001"; + + let subtracted_result = U256::from_be_hex(subtracted_result_str); + // w4 * 2^256 is congruent to w4 * (2^256 - MODULUS) mod MODULUS + let reduced_carry = mul_inner(subtracted_result, w4); // Modular addition of non-carry and reduced carry let non_carries = U256::new([w0, w1, w2, w3]); add(non_carries, reduced_carry) } +fn mul_inner(a: U256, b: Limb) -> U256 { + let a_limbs = a.as_limbs(); + let (w0, carry) = Limb::ZERO.mac(a_limbs[0], b, Limb::ZERO); + let (w1, carry) = Limb::ZERO.mac(a_limbs[1], b, carry); + let (w2, carry) = Limb::ZERO.mac(a_limbs[2], b, carry); + let (w3, w4) = Limb::ZERO.mac(a_limbs[3], b, carry); + let non_carries = U256::new([w0, w1, w2, w3]); + + let (c0, carry) = Limb::ZERO.mac(a_limbs[0], w4, Limb::ZERO); + let (c1, carry) = Limb::ZERO.mac(a_limbs[1], w4, carry); + let (c2, carry) = Limb::ZERO.mac(a_limbs[2], w4, carry); + let (c3, _) = Limb::ZERO.mac(a_limbs[3], w4, carry); + let reduced_carry = U256::new([c0, c1, c2, c3]); + + add(non_carries, reduced_carry) +} + /// Returns self * rhs mod p pub(super) fn mul(a: U256, b: U256) -> U256 { let (lo, hi) = a.mul_wide(&b);