From e185711b6ba8f3e22f2af8bf24a5fc84b781ca46 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sat, 23 Sep 2023 00:26:14 -0700 Subject: [PATCH] Sync with upstream PSE (#7) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add field conversion to/from `[u64;4]` (#80) * feat: add field conversion to/from `[u64;4]` * Added conversion tests * Added `montgomery_reduce_short` for no-asm * For bn256, uses assembly conversion when asm feature is on * fix: remove conflict for asm * chore: bump rust-toolchain to 1.67.0 * Compute Legendre symbol for `hash_to_curve` (#77) * Add `Legendre` trait and macro - Add Legendre macro with norm and legendre symbol computation - Add macro for automatic implementation in prime fields * Add legendre macro call for prime fields * Remove unused imports * Remove leftover * Add `is_quadratic_non_residue` for hash_to_curve * Add `legendre` function * Compute modulus separately * Substitute division for shift * Update modulus computation * Add quadratic residue check func * Add quadratic residue tests * Add hash_to_curve bench * Implement Legendre trait for all curves * Move misplaced comment * Add all curves to hash bench * fix: add suggestion for legendre_exp * fix: imports after rebase * Add simplified SWU method (#81) * Fix broken link * Add simple SWU algorithm * Add simplified SWU hash_to_curve for secp256r1 * add: sswu z reference * update MAP_ID identifier Co-authored-by: Han --------- Co-authored-by: Han * Bring back curve algorithms for `a = 0` (#82) * refactor: bring back curve algorithms for `a = 0` * fix: clippy warning * fix: Improve serialization for prime fields (#85) * fix: Improve serialization for prime fields Summary: 256-bit field serialization is currently 4x u64, ie. the native format. 
This implements the standard of byte-serialization (corresponding to the PrimeField::{to,from}_repr), and a hex-encoded variant of that for (de)serializers that are human-readable (concretely, json). - Added a new macro `serialize_deserialize_32_byte_primefield!` for custom serialization and deserialization of 32-byte prime fields in different structs (Fq, Fp, Fr) across the secp256r, bn256, and derive libraries. - Implemented the new macro for serialization and deserialization in various structs, replacing the previous `serde::{Deserialize, Serialize}` direct use. - Enhanced error checking in the custom serialization methods to ensure valid field elements. - Updated the test function in the tests/field.rs file to include JSON serialization and deserialization tests for object integrity checking. * fixup! fix: Improve serialization for prime fields --------- Co-authored-by: Carlos Pérez <37264926+CPerezz@users.noreply.github.com> * refactor: (De)Serialization of points using `GroupEncoding` (#88) * refactor: implement (De)Serialization of points using the `GroupEncoding` trait - Updated curve point (de)serialization logic from the internal representation to the representation offered by the implementation of the `GroupEncoding` trait. * fix: add explicit json serde tests * Insert MSM and FFT code and their benchmarks. (#86) * Insert MSM and FFT code and their benchmarks. Resolves taikoxyz/zkevm-circuits#150. * feedback * Add instructions * feedback * Implement feedback: Actually supply the correct arguments to `best_multiexp`. Split into `singlecore` and `multicore` benchmarks so Criterion's result caching and comparison over multiple runs make sense. Rewrite point and scalar generation. * Use slicing and parallelism to decrease running time. 
Laptop measurements: k=22: 109 sec k=16: 1 sec * Refactor msm * Refactor fft * Update module comments * Fix formatting * Implement suggestion for fixing CI --------- Co-authored-by: David Nevado Co-authored-by: Han Co-authored-by: François Garillot <4142+huitseeker@users.noreply.github.com> Co-authored-by: Carlos Pérez <37264926+CPerezz@users.noreply.github.com> Co-authored-by: einar-taiko <126954546+einar-taiko@users.noreply.github.com> --- Cargo.toml | 21 +++++- benches/fft.rs | 57 +++++++++++++++ benches/group.rs | 16 ++-- benches/hash_to_curve.rs | 59 +++++++++++++++ benches/msm.rs | 116 +++++++++++++++++++++++++++++ src/bn256/fq.rs | 40 +++------- src/bn256/fq2.rs | 56 +++++++------- src/bn256/fr.rs | 17 +++-- src/bn256/mod.rs | 7 -- src/derive/curve.rs | 79 ++++++++++++++++++-- src/derive/field.rs | 35 +++++++++ src/fft.rs | 134 ++++++++++++++++++++++++++++++++++ src/hash_to_curve.rs | 138 +++++++++++++++++++++++++++++++++-- src/legendre.rs | 50 +++++++++++++ src/lib.rs | 5 ++ src/msm.rs | 153 +++++++++++++++++++++++++++++++++++++++ src/multicore.rs | 16 ++++ src/pasta/mod.rs | 9 +++ src/secp256k1/fp.rs | 14 +++- src/secp256k1/fq.rs | 13 +++- src/secp256r1/curve.rs | 124 ++++++++++++++++++------------- src/secp256r1/fp.rs | 14 +++- src/secp256r1/fq.rs | 15 ++-- src/tests/curve.rs | 17 ++++- src/tests/field.rs | 23 +++++- 25 files changed, 1066 insertions(+), 162 deletions(-) create mode 100644 benches/fft.rs create mode 100644 benches/hash_to_curve.rs create mode 100644 benches/msm.rs create mode 100644 src/fft.rs create mode 100644 src/legendre.rs create mode 100644 src/msm.rs create mode 100644 src/multicore.rs diff --git a/Cargo.toml b/Cargo.toml index 817c80b6..7b0aa1b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ criterion = { version = "0.3", features = ["html_reports"] } rand_xorshift = "0.3" ark-std = { version = "0.3" } bincode = "1.3.3" +serde_json = "1.0.105" [dependencies] subtle = "2.4" @@ -30,14 +31,17 @@ num-traits = "0.2" paste = 
"1.0.11" serde = { version = "1.0", default-features = false, optional = true } serde_arrays = { version = "0.1.0", optional = true } +hex = { version = "0.4", optional = true, default-features = false, features = ["alloc", "serde"] } blake2b_simd = "1" +maybe-rayon = { version = "0.1.0", default-features = false } [features] -default = ["reexport", "bits", "bn256-table", "derive_serde"] +default = ["reexport", "bits", "multicore", "bn256-table", "derive_serde"] +multicore = ["maybe-rayon/threads"] asm = [] bits = ["ff/bits"] bn256-table = [] -derive_serde = ["serde/derive", "serde_arrays"] +derive_serde = ["serde/derive", "serde_arrays", "hex"] prefetch = [] print-trace = ["ark-std/print-trace"] reexport = [] @@ -63,3 +67,16 @@ required-features = ["reexport"] [[bench]] name = "group" harness = false + +[[bench]] +name = "hash_to_curve" +harness = false + +[[bench]] +name = "fft" +harness = false + +[[bench]] +name = "msm" +harness = false +required-features = ["multicore"] diff --git a/benches/fft.rs b/benches/fft.rs new file mode 100644 index 00000000..a250308d --- /dev/null +++ b/benches/fft.rs @@ -0,0 +1,57 @@ +//! This benchmarks Fast-Fourier Transform (FFT). +//! Since it is over a finite field, it is actually the Number Theoretical +//! Transform (NNT). It uses the `Fr` scalar field from the BN256 curve. +//! +//! To run this benchmark: +//! +//! cargo bench -- fft +//! +//! Caveat: The multicore benchmark assumes: +//! 1. a multi-core system +//! 2. that the `multicore` feature is enabled. It is by default. 
+ +#[macro_use] +extern crate criterion; + +use criterion::{BenchmarkId, Criterion}; +use group::ff::Field; +use halo2curves::bn256::Fr as Scalar; +use halo2curves::fft::best_fft; +use rand_core::OsRng; +use std::ops::Range; +use std::time::SystemTime; + +const RANGE: Range = 3..19; + +fn generate_data(k: u32) -> Vec { + let n = 1 << k; + let timer = SystemTime::now(); + println!("\n\nGenerating 2^{k} = {n} values..",); + let data: Vec = (0..n).map(|_| Scalar::random(OsRng)).collect(); + let end = timer.elapsed().unwrap(); + println!( + "Generating 2^{k} = {n} values took: {} sec.\n\n", + end.as_secs() + ); + data +} + +fn fft(c: &mut Criterion) { + let max_k = RANGE.max().unwrap_or(16); + let mut data = generate_data(max_k); + let omega = Scalar::random(OsRng); + let mut group = c.benchmark_group("fft"); + for k in RANGE { + group.bench_function(BenchmarkId::new("k", k), |b| { + let n = 1 << k; + assert!(n <= data.len()); + b.iter(|| { + best_fft(&mut data[..n], omega, k); + }); + }); + } + group.finish(); +} + +criterion_group!(benches, fft); +criterion_main!(benches); diff --git a/benches/group.rs b/benches/group.rs index 68cfee53..b1936e68 100644 --- a/benches/group.rs +++ b/benches/group.rs @@ -18,28 +18,28 @@ fn criterion_benchmark(c: &mut Criterion) { let v = vec![G::generator(); N]; let mut q = vec![G::AffineExt::identity(); N]; - c.bench_function(&format!("{} check on curve", name), move |b| { + c.bench_function(&format!("{name} check on curve"), move |b| { b.iter(|| black_box(p1).is_on_curve()) }); - c.bench_function(&format!("{} check equality", name), move |b| { + c.bench_function(&format!("{name} check equality"), move |b| { b.iter(|| black_box(p1) == black_box(p1)) }); - c.bench_function(&format!("{} to affine", name), move |b| { + c.bench_function(&format!("{name} to affine"), move |b| { b.iter(|| G::AffineExt::from(black_box(p1))) }); - c.bench_function(&format!("{} doubling", name), move |b| { + c.bench_function(&format!("{name} doubling"), move 
|b| { b.iter(|| black_box(p1).double()) }); - c.bench_function(&format!("{} addition", name), move |b| { + c.bench_function(&format!("{name} addition"), move |b| { b.iter(|| black_box(p1).add(&p2)) }); - c.bench_function(&format!("{} mixed addition", name), move |b| { + c.bench_function(&format!("{name} mixed addition"), move |b| { b.iter(|| black_box(p2).add(&p1_affine)) }); - c.bench_function(&format!("{} scalar multiplication", name), move |b| { + c.bench_function(&format!("{name} scalar multiplication"), move |b| { b.iter(|| black_box(p1) * black_box(s)) }); - c.bench_function(&format!("{} batch to affine n={}", name, N), move |b| { + c.bench_function(&format!("{name} batch to affine n={N}"), move |b| { b.iter(|| { G::batch_normalize(black_box(&v), black_box(&mut q)); black_box(&q)[0] diff --git a/benches/hash_to_curve.rs b/benches/hash_to_curve.rs new file mode 100644 index 00000000..bda1c1d3 --- /dev/null +++ b/benches/hash_to_curve.rs @@ -0,0 +1,59 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use pasta_curves::arithmetic::CurveExt; +use rand_core::{OsRng, RngCore}; +use std::iter; + +fn hash_to_secp256k1(c: &mut Criterion) { + hash_to_curve::(c, "Secp256k1"); +} + +fn hash_to_secq256k1(c: &mut Criterion) { + hash_to_curve::(c, "Secq256k1"); +} + +fn hash_to_secp256r1(c: &mut Criterion) { + hash_to_curve::(c, "Secp256r1"); +} + +fn hash_to_pallas(c: &mut Criterion) { + hash_to_curve::(c, "Pallas"); +} + +fn hash_to_vesta(c: &mut Criterion) { + hash_to_curve::(c, "Vesta"); +} + +fn hash_to_bn256(c: &mut Criterion) { + hash_to_curve::(c, "Bn256"); +} + +fn hash_to_grumpkin(c: &mut Criterion) { + hash_to_curve::(c, "Grumpkin"); +} + +fn hash_to_curve(c: &mut Criterion, name: &'static str) { + { + let hasher = G::hash_to_curve("test"); + let mut rng = OsRng; + let message = iter::repeat_with(|| rng.next_u32().to_be_bytes()) + .take(1024) + .flatten() + .collect::>(); + + c.bench_function(&format!("Hash to {name}"), move |b| { + 
b.iter(|| hasher(black_box(&message))) + }); + } +} + +criterion_group!( + benches, + hash_to_secp256k1, + hash_to_secq256k1, + hash_to_secp256r1, + hash_to_pallas, + hash_to_vesta, + hash_to_bn256, + hash_to_grumpkin, +); +criterion_main!(benches); diff --git a/benches/msm.rs b/benches/msm.rs new file mode 100644 index 00000000..c78952b7 --- /dev/null +++ b/benches/msm.rs @@ -0,0 +1,116 @@ +//! This benchmarks Multi Scalar Multiplication (MSM). +//! It measures `G1` from the BN256 curve. +//! +//! To run this benchmark: +//! +//! cargo bench -- msm +//! +//! Caveat: The multicore benchmark assumes: +//! 1. a multi-core system +//! 2. that the `multicore` feature is enabled. It is by default. + +#[macro_use] +extern crate criterion; + +use criterion::{BenchmarkId, Criterion}; +use ff::Field; +use group::prime::PrimeCurveAffine; +use halo2curves::bn256::{Fr as Scalar, G1Affine as Point}; +use halo2curves::msm::{best_multiexp, multiexp_serial}; +use maybe_rayon::current_thread_index; +use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use rand_core::SeedableRng; +use rand_xorshift::XorShiftRng; +use std::time::SystemTime; + +const SAMPLE_SIZE: usize = 10; +const SINGLECORE_RANGE: [u8; 6] = [3, 8, 10, 12, 14, 16]; +const MULTICORE_RANGE: [u8; 9] = [3, 8, 10, 12, 14, 16, 18, 20, 22]; +const SEED: [u8; 16] = [ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, 0xe5, +]; + +fn generate_coefficients_and_curvepoints(k: u8) -> (Vec, Vec) { + let n: u64 = { + assert!(k < 64); + 1 << k + }; + + println!("\n\nGenerating 2^{k} = {n} coefficients and curve points..",); + let timer = SystemTime::now(); + let coeffs = (0..n) + .into_par_iter() + .map_init( + || { + let mut thread_seed = SEED; + let uniq = current_thread_index().unwrap().to_ne_bytes(); + assert!(std::mem::size_of::() == 8); + for i in 0..uniq.len() { + thread_seed[i] += uniq[i]; + thread_seed[i + 8] += uniq[i]; + } + XorShiftRng::from_seed(thread_seed) + }, 
+ |rng, _| Scalar::random(rng), + ) + .collect(); + let bases = (0..n) + .into_par_iter() + .map_init( + || { + let mut thread_seed = SEED; + let uniq = current_thread_index().unwrap().to_ne_bytes(); + assert!(std::mem::size_of::() == 8); + for i in 0..uniq.len() { + thread_seed[i] += uniq[i]; + thread_seed[i + 8] += uniq[i]; + } + XorShiftRng::from_seed(thread_seed) + }, + |rng, _| Point::random(rng), + ) + .collect(); + let end = timer.elapsed().unwrap(); + println!( + "Generating 2^{k} = {n} coefficients and curve points took: {} sec.\n\n", + end.as_secs() + ); + + (coeffs, bases) +} + +fn msm(c: &mut Criterion) { + let mut group = c.benchmark_group("msm"); + let max_k = *SINGLECORE_RANGE + .iter() + .chain(MULTICORE_RANGE.iter()) + .max() + .unwrap_or(&16); + let (coeffs, bases) = generate_coefficients_and_curvepoints(max_k); + + for k in SINGLECORE_RANGE { + group + .bench_function(BenchmarkId::new("singlecore", k), |b| { + assert!(k < 64); + let n: usize = 1 << k; + let mut acc = Point::identity().into(); + b.iter(|| multiexp_serial(&coeffs[..n], &bases[..n], &mut acc)); + }) + .sample_size(10); + } + for k in MULTICORE_RANGE { + group + .bench_function(BenchmarkId::new("multicore", k), |b| { + assert!(k < 64); + let n: usize = 1 << k; + b.iter(|| { + best_multiexp(&coeffs[..n], &bases[..n]); + }) + }) + .sample_size(SAMPLE_SIZE); + } + group.finish(); +} + +criterion_group!(benches, msm); +criterion_main!(benches); diff --git a/src/bn256/fq.rs b/src/bn256/fq.rs index 951de3b4..56be690f 100644 --- a/src/bn256/fq.rs +++ b/src/bn256/fq.rs @@ -1,11 +1,10 @@ #[cfg(feature = "asm")] use crate::bn256::assembly::field_arithmetic_asm; #[cfg(not(feature = "asm"))] -use crate::{field_arithmetic, field_specific}; +use crate::{arithmetic::macx, field_arithmetic, field_specific}; -use crate::arithmetic::{adc, mac, macx, sbb}; -use crate::bn256::LegendreSymbol; -use crate::ff::{Field, FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; +use crate::arithmetic::{adc, mac, 
sbb}; +use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_bits, field_common, impl_add_binop_specify_output, impl_binops_additive, impl_binops_additive_specify_output, impl_binops_multiplicative, @@ -17,9 +16,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_q$ where /// /// `p = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47` @@ -29,9 +25,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fq` values are always in // Montgomery form; i.e., Fq(a) = aR mod q, with R = 2^256. #[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fq(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fq); + /// Constant representing the modulus /// q = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 const MODULUS: Fq = Fq([ @@ -160,27 +158,10 @@ impl Fq { pub const fn size() -> usize { 32 } - - pub fn legendre(&self) -> LegendreSymbol { - // s = self^((modulus - 1) // 2) - // 0x183227397098d014dc2822db40c0ac2ecbc0b548b438e5469e10460b6c3e7ea3 - let s = &[ - 0x9e10460b6c3e7ea3u64, - 0xcbc0b548b438e546u64, - 0xdc2822db40c0ac2eu64, - 0x183227397098d014u64, - ]; - let s = self.pow(s); - if s == Self::zero() { - LegendreSymbol::Zero - } else if s == Self::one() { - LegendreSymbol::QuadraticResidue - } else { - LegendreSymbol::QuadraticNonResidue - } - } } +prime_field_legendre!(Fq); + impl ff::Field for Fq { const ZERO: Self = Self::zero(); const ONE: Self = Self::one(); @@ -303,6 +284,7 @@ impl WithSmallOrderMulGroup<3> for Fq { #[cfg(test)] mod test { use super::*; + use crate::legendre::Legendre; use ff::Field; use rand_core::OsRng; @@ -315,7 +297,7 @@ mod 
test { let a = Fq::random(OsRng); let mut b = a; b = b.square(); - assert_eq!(b.legendre(), LegendreSymbol::QuadraticResidue); + assert_eq!(b.legendre(), Fq::ONE); let b = b.sqrt().unwrap(); let mut negb = b; @@ -328,7 +310,7 @@ mod test { for _ in 0..10000 { let mut b = c; b = b.square(); - assert_eq!(b.legendre(), LegendreSymbol::QuadraticResidue); + assert_eq!(b.legendre(), Fq::ONE); b = b.sqrt().unwrap(); diff --git a/src/bn256/fq2.rs b/src/bn256/fq2.rs index e5a249ee..66d2c6a7 100644 --- a/src/bn256/fq2.rs +++ b/src/bn256/fq2.rs @@ -1,6 +1,6 @@ use super::fq::{Fq, NEGATIVE_ONE}; -use super::LegendreSymbol; use crate::ff::{Field, FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; +use crate::legendre::Legendre; use core::convert::TryInto; use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; @@ -125,6 +125,30 @@ impl_binops_additive!(Fq2, Fq2); impl_binops_multiplicative!(Fq2, Fq2); impl_sum_prod!(Fq2); +impl Legendre for Fq2 { + type BasePrimeField = Fq; + + #[inline] + fn legendre_exp() -> &'static [u64] { + lazy_static::lazy_static! 
{ + // (p-1) / 2 + static ref LEGENDRE_EXP: Vec = + (num_bigint::BigUint::from_bytes_le((-::ONE).to_repr().as_ref())/2usize).to_u64_digits(); + } + &LEGENDRE_EXP + } + + /// Norm of Fq2 as extension field in i over Fq + #[inline] + fn norm(&self) -> Self::BasePrimeField { + let mut t0 = self.c0; + let mut t1 = self.c1; + t0 = t0.square(); + t1 = t1.square(); + t1 + t0 + } +} + impl Fq2 { #[inline] pub const fn zero() -> Fq2 { @@ -174,10 +198,6 @@ impl Fq2 { res } - pub fn legendre(&self) -> LegendreSymbol { - self.norm().legendre() - } - pub fn mul_assign(&mut self, other: &Self) { let mut t1 = self.c0 * other.c0; let mut t0 = self.c0 + self.c1; @@ -298,15 +318,6 @@ impl Fq2 { self.c1 += &t0; } - /// Norm of Fq2 as extension field in i over Fq - pub fn norm(&self) -> Fq { - let mut t0 = self.c0; - let mut t1 = self.c1; - t0 = t0.square(); - t1 = t1.square(); - t1 + t0 - } - pub fn invert(&self) -> CtOption { let mut t1 = self.c1; t1 = t1.square(); @@ -696,17 +707,6 @@ fn test_fq2_mul_nonresidue() { } } -#[test] -fn test_fq2_legendre() { - assert_eq!(LegendreSymbol::Zero, Fq2::ZERO.legendre()); - // i^2 = -1 - let mut m1 = Fq2::ONE; - m1 = m1.neg(); - assert_eq!(LegendreSymbol::QuadraticResidue, m1.legendre()); - m1.mul_by_nonresidue(); - assert_eq!(LegendreSymbol::QuadraticNonResidue, m1.legendre()); -} - #[test] pub fn test_sqrt() { let mut rng = XorShiftRng::from_seed([ @@ -716,7 +716,7 @@ pub fn test_sqrt() { for _ in 0..10000 { let a = Fq2::random(&mut rng); - if a.legendre() == LegendreSymbol::QuadraticNonResidue { + if a.legendre() == -Fq::ONE { assert!(bool::from(a.sqrt().is_none())); } } @@ -725,7 +725,7 @@ pub fn test_sqrt() { let a = Fq2::random(&mut rng); let mut b = a; b.square_assign(); - assert_eq!(b.legendre(), LegendreSymbol::QuadraticResidue); + assert_eq!(b.legendre(), Fq::ONE); let b = b.sqrt().unwrap(); let mut negb = b; @@ -738,7 +738,7 @@ pub fn test_sqrt() { for _ in 0..10000 { let mut b = c; b.square_assign(); - assert_eq!(b.legendre(), 
LegendreSymbol::QuadraticResidue); + assert_eq!(b.legendre(), Fq::ONE); b = b.sqrt().unwrap(); diff --git a/src/bn256/fr.rs b/src/bn256/fr.rs index 3db131bb..9e02fcdc 100644 --- a/src/bn256/fr.rs +++ b/src/bn256/fr.rs @@ -1,7 +1,7 @@ #[cfg(feature = "asm")] use crate::bn256::assembly::field_arithmetic_asm; #[cfg(not(feature = "asm"))] -use crate::{field_arithmetic, field_specific}; +use crate::{arithmetic::macx, field_arithmetic, field_specific}; #[cfg(feature = "bn256-table")] #[rustfmt::skip] @@ -18,7 +18,7 @@ pub use table::FR_TABLE; #[cfg(not(feature = "bn256-table"))] use crate::impl_from_u64; -use crate::arithmetic::{adc, mac, macx, sbb}; +use crate::arithmetic::{adc, mac, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; use crate::{ field_bits, field_common, impl_add_binop_specify_output, impl_binops_additive, @@ -31,9 +31,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_r$ where /// /// `r = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001` @@ -43,9 +40,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fr` values are always in // Montgomery form; i.e., Fr(a) = aR mod r, with R = 2^256. 
#[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fr(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fr); + /// Constant representing the modulus /// r = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 const MODULUS: Fr = Fr([ @@ -166,6 +165,7 @@ field_common!( R3 ); impl_sum_prod!(Fr); +prime_field_legendre!(Fr); #[cfg(not(feature = "bn256-table"))] impl_from_u64!(Fr, R2); @@ -463,4 +463,9 @@ mod test { end_timer!(timer); } + + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/bn256/mod.rs b/src/bn256/mod.rs index 9cd08946..3530b765 100644 --- a/src/bn256/mod.rs +++ b/src/bn256/mod.rs @@ -16,10 +16,3 @@ pub use fq12::*; pub use fq2::*; pub use fq6::*; pub use fr::*; - -#[derive(Debug, PartialEq, Eq)] -pub enum LegendreSymbol { - Zero = 0, - QuadraticResidue = 1, - QuadraticNonResidue = -1, -} diff --git a/src/derive/curve.rs b/src/derive/curve.rs index 552b0b84..467c1be0 100644 --- a/src/derive/curve.rs +++ b/src/derive/curve.rs @@ -288,8 +288,72 @@ macro_rules! new_curve_impl { } + /// A macro to help define point serialization using the [`group::GroupEncoding`] trait + /// This assumes both point types ($name, $nameaffine) implement [`group::GroupEncoding`]. + #[cfg(feature = "derive_serde")] + macro_rules! serialize_deserialize_to_from_bytes { + () => { + impl ::serde::Serialize for $name { + fn serialize(&self, serializer: S) -> Result { + let bytes = &self.to_bytes(); + if serializer.is_human_readable() { + ::hex::serde::serialize(&bytes.0, serializer) + } else { + ::serde_arrays::serialize(&bytes.0, serializer) + } + } + } + + paste::paste! 
{ + use ::serde::de::Error as _; + impl<'de> ::serde::Deserialize<'de> for $name { + fn deserialize>( + deserializer: D, + ) -> Result { + let bytes = if deserializer.is_human_readable() { + ::hex::serde::deserialize(deserializer)? + } else { + ::serde_arrays::deserialize::<_, u8, [< $name _COMPRESSED_SIZE >]>(deserializer)? + }; + Option::from(Self::from_bytes(&[< $name Compressed >](bytes))).ok_or_else(|| { + D::Error::custom("deserialized bytes don't encode a valid field element") + }) + } + } + } + + impl ::serde::Serialize for $name_affine { + fn serialize(&self, serializer: S) -> Result { + let bytes = &self.to_bytes(); + if serializer.is_human_readable() { + ::hex::serde::serialize(&bytes.0, serializer) + } else { + ::serde_arrays::serialize(&bytes.0, serializer) + } + } + } + + paste::paste! { + use ::serde::de::Error as _; + impl<'de> ::serde::Deserialize<'de> for $name_affine { + fn deserialize>( + deserializer: D, + ) -> Result { + let bytes = if deserializer.is_human_readable() { + ::hex::serde::deserialize(deserializer)? + } else { + ::serde_arrays::deserialize::<_, u8, [< $name _COMPRESSED_SIZE >]>(deserializer)? + }; + Option::from(Self::from_bytes(&[< $name Compressed >](bytes))).ok_or_else(|| { + D::Error::custom("deserialized bytes don't encode a valid field element") + }) + } + } + } + }; + } + #[derive(Copy, Clone, Debug)] - #[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] $($privacy)* struct $name { pub x: $base, pub y: $base, @@ -297,13 +361,13 @@ macro_rules! new_curve_impl { } #[derive(Copy, Clone, PartialEq)] - #[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] $($privacy)* struct $name_affine { pub x: $base, pub y: $base, } - + #[cfg(feature = "derive_serde")] + serialize_deserialize_to_from_bytes!(); impl_compressed!(); impl_uncompressed!(); @@ -473,17 +537,18 @@ macro_rules! 
new_curve_impl { fn is_on_curve(&self) -> Choice { if $constant_a == $base::ZERO { // Check (Y/Z)^2 = (X/Z)^3 + b - // <=> Z Y^2 - X^3 = Z^3 b + // <=> Z Y^2 - X^3 = Z^3 b (self.z * self.y.square() - self.x.square() * self.x) .ct_eq(&(self.z.square() * self.z * $constant_b)) | self.z.is_zero() } else { // Check (Y/Z)^2 = (X/Z)^3 + a(X/Z) + b - // <=> Z Y^2 - X^3 - a(X Z^2) = Z^3 b + // <=> Z Y^2 - X^3 - a(X Z^2) = Z^3 b - (self.z * self.y.square() - self.x.square() * self.x - $constant_a * self.x * self.z.square()) - .ct_eq(&(self.z.square() * self.z * $constant_b)) + let z2 = self.z.square(); + (self.z * self.y.square() - (self.x.square() + $constant_a * z2) * self.x) + .ct_eq(&(z2 * self.z * $constant_b)) | self.z.is_zero() } } diff --git a/src/derive/field.rs b/src/derive/field.rs index e2bd8111..bdef8606 100644 --- a/src/derive/field.rs +++ b/src/derive/field.rs @@ -702,3 +702,38 @@ macro_rules! field_bits { } }; } + +/// A macro to help define serialization and deserialization for prime field implementations +/// that use 32-byte representations. This assumes the concerned type implements PrimeField +/// (for from_repr, to_repr). +#[macro_export] +macro_rules! serialize_deserialize_32_byte_primefield { + ($type:ty) => { + impl ::serde::Serialize for $type { + fn serialize(&self, serializer: S) -> Result { + let bytes = &self.to_repr(); + if serializer.is_human_readable() { + hex::serde::serialize(bytes, serializer) + } else { + bytes.serialize(serializer) + } + } + } + + use ::serde::de::Error as _; + impl<'de> ::serde::Deserialize<'de> for $type { + fn deserialize>( + deserializer: D, + ) -> Result { + let bytes = if deserializer.is_human_readable() { + ::hex::serde::deserialize(deserializer)? + } else { + <[u8; 32]>::deserialize(deserializer)? 
+ }; + Option::from(Self::from_repr(bytes)).ok_or_else(|| { + D::Error::custom("deserialized bytes don't encode a valid field element") + }) + } + } + }; +} diff --git a/src/fft.rs b/src/fft.rs new file mode 100644 index 00000000..6eb3487e --- /dev/null +++ b/src/fft.rs @@ -0,0 +1,134 @@ +use crate::multicore; +pub use crate::{CurveAffine, CurveExt}; +use ff::Field; +use group::{GroupOpsOwned, ScalarMulOwned}; + +/// This represents an element of a group with basic operations that can be +/// performed. This allows an FFT implementation (for example) to operate +/// generically over either a field or elliptic curve group. +pub trait FftGroup: + Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned +{ +} + +impl FftGroup for T +where + Scalar: Field, + T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned, +{ +} + +/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size +/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative +/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when +/// interpreted as the coefficients of a polynomial of degree $n - 1$, is +/// transformed into the evaluations of this polynomial at each of the $n$ +/// distinct powers of $\omega$. This transformation is invertible by providing +/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element +/// by $n$. +/// +/// This will use multithreading if beneficial. 
+pub fn best_fft>(a: &mut [G], omega: Scalar, log_n: u32) { + fn bitreverse(mut n: usize, l: usize) -> usize { + let mut r = 0; + for _ in 0..l { + r = (r << 1) | (n & 1); + n >>= 1; + } + r + } + + let threads = multicore::current_num_threads(); + let log_threads = threads.ilog2(); + let n = a.len(); + assert_eq!(n, 1 << log_n); + + for k in 0..n { + let rk = bitreverse(k, log_n as usize); + if k < rk { + a.swap(rk, k); + } + } + + // precompute twiddle factors + let twiddles: Vec<_> = (0..(n / 2)) + .scan(Scalar::ONE, |w, _| { + let tw = *w; + *w *= ω + Some(tw) + }) + .collect(); + + if log_n <= log_threads { + let mut chunk = 2_usize; + let mut twiddle_chunk = n / 2; + for _ in 0..log_n { + a.chunks_mut(chunk).for_each(|coeffs| { + let (left, right) = coeffs.split_at_mut(chunk / 2); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0] += &t; + b[0] -= &t; + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t *= &twiddles[(i + 1) * twiddle_chunk]; + *b = *a; + *a += &t; + *b -= &t; + }); + }); + chunk *= 2; + twiddle_chunk /= 2; + } + } else { + recursive_butterfly_arithmetic(a, n, 1, &twiddles) + } +} + +/// This perform recursive butterfly arithmetic +pub fn recursive_butterfly_arithmetic>( + a: &mut [G], + n: usize, + twiddle_chunk: usize, + twiddles: &[Scalar], +) { + if n == 2 { + let t = a[1]; + a[1] = a[0]; + a[0] += &t; + a[1] -= &t; + } else { + let (left, right) = a.split_at_mut(n / 2); + multicore::join( + || recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles), + || recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles), + ); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0] += &t; + b[0] -= &t; + + left.iter_mut() + .zip(right.iter_mut()) + 
.enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t *= &twiddles[(i + 1) * twiddle_chunk]; + *b = *a; + *a += &t; + *b -= &t; + }); + } +} diff --git a/src/hash_to_curve.rs b/src/hash_to_curve.rs index 4cef7095..22251102 100644 --- a/src/hash_to_curve.rs +++ b/src/hash_to_curve.rs @@ -3,7 +3,9 @@ use ff::{Field, FromUniformBytes, PrimeField}; use pasta_curves::arithmetic::CurveExt; use static_assertions::const_assert; -use subtle::{ConditionallySelectable, ConstantTimeEq}; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; + +use crate::legendre::Legendre; /// Hashes over a message and writes the output to all of `buf`. /// Modified from https://github.com/zcash/pasta_curves/blob/7e3fc6a4919f6462a32b79dd226cb2587b7961eb/src/hashtocurve.rs#L11. @@ -83,6 +85,94 @@ fn hash_to_field>( } } +// Implementation of +#[allow(clippy::too_many_arguments)] +pub(crate) fn simple_svdw_map_to_curve(u: C::Base, z: C::Base) -> C +where + C: CurveExt, +{ + let zero = C::Base::ZERO; + let one = C::Base::ONE; + let a = C::a(); + let b = C::b(); + + //1. tv1 = u^2 + let tv1 = u.square(); + //2. tv1 = Z * tv1 + let tv1 = z * tv1; + //3. tv2 = tv1^2 + let tv2 = tv1.square(); + //4. tv2 = tv2 + tv1 + let tv2 = tv2 + tv1; + //5. tv3 = tv2 + 1 + let tv3 = tv2 + one; + //6. tv3 = B * tv3 + let tv3 = b * tv3; + //7. tv4 = CMOV(Z, -tv2, tv2 != 0) # tv4 = z if tv2 is 0 else tv4 = -tv2 + let tv2_is_not_zero = !tv2.ct_eq(&zero); + let tv4 = C::Base::conditional_select(&z, &-tv2, tv2_is_not_zero); + //8. tv4 = A * tv4 + let tv4 = a * tv4; + //9. tv2 = tv3^2 + let tv2 = tv3.square(); + //10. tv6 = tv4^2 + let tv6 = tv4.square(); + //11. tv5 = A * tv6 + let tv5 = a * tv6; + //12. tv2 = tv2 + tv5 + let tv2 = tv2 + tv5; + //13. tv2 = tv2 * tv3 + let tv2 = tv2 * tv3; + //14. tv6 = tv6 * tv4 + let tv6 = tv6 * tv4; + //15. tv5 = B * tv6 + let tv5 = b * tv6; + //16. tv2 = tv2 + tv5 + let tv2 = tv2 + tv5; + //17. x = tv1 * tv3 + let x = tv1 * tv3; + //18. 
(is_gx1_square, y1) = sqrt_ratio(tv2, tv6) + let (is_gx1_square, y1) = sqrt_ratio(&tv2, &tv6, &z); + //19. y = tv1 * u + let y = tv1 * u; + //20. y = y * y1 + let y = y * y1; + //21. x = CMOV(x, tv3, is_gx1_square) + let x = C::Base::conditional_select(&x, &tv3, is_gx1_square); + //22. y = CMOV(y, y1, is_gx1_square) + let y = C::Base::conditional_select(&y, &y1, is_gx1_square); + //23. e1 = sgn0(u) == sgn0(y) + let e1 = u.is_odd().ct_eq(&y.is_odd()); + //24. y = CMOV(-y, y, e1) # Select correct sign of y + let y = C::Base::conditional_select(&-y, &y, e1); + //25. x = x / tv4 + let x = x * tv4.invert().unwrap(); + //26. return (x, y) + C::new_jacobian(x, y, one).unwrap() +} + +#[allow(clippy::type_complexity)] +pub(crate) fn simple_svdw_hash_to_curve<'a, C>( + curve_id: &'static str, + domain_prefix: &'a str, + z: C::Base, +) -> Box C + 'a> +where + C: CurveExt, + C::Base: FromUniformBytes<64>, +{ + Box::new(move |message| { + let mut us = [C::Base::ZERO; 2]; + hash_to_field("SSWU", curve_id, domain_prefix, message, &mut us); + + let [q0, q1]: [C; 2] = us.map(|u| simple_svdw_map_to_curve(u, z)); + + let r = q0 + &q1; + debug_assert!(bool::from(r.is_on_curve())); + r + }) +} + #[allow(clippy::too_many_arguments)] pub(crate) fn svdw_map_to_curve( u: C::Base, @@ -94,6 +184,7 @@ pub(crate) fn svdw_map_to_curve( ) -> C where C: CurveExt, + C::Base: Legendre, { let one = C::Base::ONE; let a = C::a(); @@ -128,7 +219,7 @@ where // 14. gx1 = gx1 + B let gx1 = gx1 + b; // 15. e1 = is_square(gx1) - let e1 = gx1.sqrt().is_some(); + let e1 = !gx1.ct_quadratic_non_residue(); // 16. x2 = c2 + tv4 let x2 = c2 + tv4; // 17. gx2 = x2^2 @@ -140,7 +231,7 @@ where // 20. gx2 = gx2 + B let gx2 = gx2 + b; // 21. e2 = is_square(gx2) AND NOT e1 # Avoid short-circuit logic ops - let e2 = gx2.sqrt().is_some() & (!e1); + let e2 = !gx2.ct_quadratic_non_residue() & (!e1); // 22. x3 = tv2^2 let x3 = tv2.square(); // 23. 
x3 = x3 * tv3 @@ -173,7 +264,44 @@ where C::new_jacobian(x, y, one).unwrap() } -/// Implementation of https://www.ietf.org/id/draft-irtf-cfrg-hash-to-curve-16.html#name-shallue-van-de-woestijne-met +// Implement https://datatracker.ietf.org/doc/html/rfc9380#name-sqrt_ratio-for-any-field +// Copied from ff sqrt_ratio_generic subsituting F::ROOT_OF_UNITY for input Z +fn sqrt_ratio(num: &F, div: &F, z: &F) -> (Choice, F) { + // General implementation: + // + // a = num * inv0(div) + // = { 0 if div is zero + // { num/div otherwise + // + // b = z * a + // = { 0 if div is zero + // { z*num/div otherwise + + // Since z is non-square, a and b are either both zero (and both square), or + // only one of them is square. We can therefore choose the square root to return + // based on whether a is square, but for the boolean output we need to handle the + // num != 0 && div == 0 case specifically. + + let a = div.invert().unwrap_or(F::ZERO) * num; + let b = a * z; + let sqrt_a = a.sqrt(); + let sqrt_b = b.sqrt(); + + let num_is_zero = num.is_zero(); + let div_is_zero = div.is_zero(); + let is_square = sqrt_a.is_some(); + let is_nonsquare = sqrt_b.is_some(); + assert!(bool::from( + num_is_zero | div_is_zero | (is_square ^ is_nonsquare) + )); + + ( + is_square & (num_is_zero | !div_is_zero), + CtOption::conditional_select(&sqrt_b, &sqrt_a, is_square).unwrap(), + ) +} + +/// Implementation of https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-10.html#section-6.6.1 #[allow(clippy::type_complexity)] pub(crate) fn svdw_hash_to_curve<'a, C>( curve_id: &'static str, @@ -182,7 +310,7 @@ pub(crate) fn svdw_hash_to_curve<'a, C>( ) -> Box C + 'a> where C: CurveExt, - C::Base: FromUniformBytes<64>, + C::Base: FromUniformBytes<64> + Legendre, { let [c1, c2, c3, c4] = svdw_precomputed_constants::(z); diff --git a/src/legendre.rs b/src/legendre.rs new file mode 100644 index 00000000..7e4b9971 --- /dev/null +++ b/src/legendre.rs @@ -0,0 +1,50 @@ +use ff::{Field, PrimeField}; +use 
subtle::{Choice, ConstantTimeEq}; + +pub trait Legendre: Field { + type BasePrimeField: PrimeField; + + // This is (p-1)/2 where p is the modulus of the base prime field + fn legendre_exp() -> &'static [u64]; + + fn norm(&self) -> Self::BasePrimeField; + + #[inline] + fn legendre(&self) -> Self::BasePrimeField { + self.norm().pow(Self::legendre_exp()) + } + + #[inline] + fn ct_quadratic_residue(&self) -> Choice { + self.legendre().ct_eq(&Self::BasePrimeField::ONE) + } + + #[inline] + fn ct_quadratic_non_residue(&self) -> Choice { + self.legendre().ct_eq(&-Self::BasePrimeField::ONE) + } +} + +#[macro_export] +macro_rules! prime_field_legendre { + ($field:ident ) => { + impl $crate::legendre::Legendre for $field { + type BasePrimeField = Self; + + #[inline] + fn legendre_exp() -> &'static [u64] { + lazy_static::lazy_static! { + // (p-1) / 2 + static ref LEGENDRE_EXP: Vec = + (num_bigint::BigUint::from_bytes_le((-<$field as ff::Field>::ONE).to_repr().as_ref())/2usize).to_u64_digits(); + } + &*LEGENDRE_EXP + } + + #[inline] + fn norm(&self) -> Self::BasePrimeField { + self.clone() + } + } + }; +} diff --git a/src/lib.rs b/src/lib.rs index 44aa63e2..f5bcea52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,11 @@ mod arithmetic; mod bernsteinyang; +pub mod fft; pub mod hash_to_curve; +pub mod msm; +pub mod multicore; +#[macro_use] +pub mod legendre; pub mod serde; pub mod bn256; diff --git a/src/msm.rs b/src/msm.rs new file mode 100644 index 00000000..de30be55 --- /dev/null +++ b/src/msm.rs @@ -0,0 +1,153 @@ +use ff::PrimeField; +use group::Group; +use pasta_curves::arithmetic::CurveAffine; + +use crate::multicore; + +pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) { + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); + + let c = if bases.len() < 4 { + 1 + } else if bases.len() < 32 { + 3 + } else { + (f64::from(bases.len() as u32)).ln().ceil() as usize + }; + + fn get_at(segment: usize, c: usize, bytes: &F::Repr) -> 
usize { + let skip_bits = segment * c; + let skip_bytes = skip_bits / 8; + + if skip_bytes >= 32 { + return 0; + } + + let mut v = [0; 8]; + for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { + *v = *o; + } + + let mut tmp = u64::from_le_bytes(v); + tmp >>= skip_bits - (skip_bytes * 8); + tmp %= 1 << c; + + tmp as usize + } + + let segments = (256 / c) + 1; + + for current_segment in (0..segments).rev() { + for _ in 0..c { + *acc = acc.double(); + } + + #[derive(Clone, Copy)] + enum Bucket { + None, + Affine(C), + Projective(C::Curve), + } + + impl Bucket { + fn add_assign(&mut self, other: &C) { + *self = match *self { + Bucket::None => Bucket::Affine(*other), + Bucket::Affine(a) => Bucket::Projective(a + *other), + Bucket::Projective(mut a) => { + a += *other; + Bucket::Projective(a) + } + } + } + + fn add(self, mut other: C::Curve) -> C::Curve { + match self { + Bucket::None => other, + Bucket::Affine(a) => { + other += a; + other + } + Bucket::Projective(a) => other + a, + } + } + } + + let mut buckets: Vec> = vec![Bucket::None; (1 << c) - 1]; + + for (coeff, base) in coeffs.iter().zip(bases.iter()) { + let coeff = get_at::(current_segment, c, coeff); + if coeff != 0 { + buckets[coeff - 1].add_assign(base); + } + } + + // Summation by parts + // e.g. 3a + 2b + 1c = a + + // (a) + b + + // ((a) + b) + c + let mut running_sum = C::Curve::identity(); + for exp in buckets.into_iter().rev() { + running_sum = exp.add(running_sum); + *acc += &running_sum; + } + } +} + +/// Performs a small multi-exponentiation operation. +/// Uses the double-and-add algorithm with doublings shared across points. 
+pub fn small_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); + let mut acc = C::Curve::identity(); + + // for byte idx + for byte_idx in (0..32).rev() { + // for bit idx + for bit_idx in (0..8).rev() { + acc = acc.double(); + // for each coeff + for coeff_idx in 0..coeffs.len() { + let byte = coeffs[coeff_idx].as_ref()[byte_idx]; + if ((byte >> bit_idx) & 1) != 0 { + acc += bases[coeff_idx]; + } + } + } + } + + acc +} + +/// Performs a multi-exponentiation operation. +/// +/// This function will panic if coeffs and bases have a different length. +/// +/// This will use multithreading if beneficial. +pub fn best_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { + assert_eq!(coeffs.len(), bases.len()); + + let num_threads = multicore::current_num_threads(); + if coeffs.len() > num_threads { + let chunk = coeffs.len() / num_threads; + let num_chunks = coeffs.chunks(chunk).len(); + let mut results = vec![C::Curve::identity(); num_chunks]; + multicore::scope(|scope| { + let chunk = coeffs.len() / num_threads; + + for ((coeffs, bases), acc) in coeffs + .chunks(chunk) + .zip(bases.chunks(chunk)) + .zip(results.iter_mut()) + { + scope.spawn(move |_| { + multiexp_serial(coeffs, bases, acc); + }); + } + }); + results.iter().fold(C::Curve::identity(), |a, b| a + b) + } else { + let mut acc = C::Curve::identity(); + multiexp_serial(coeffs, bases, &mut acc); + acc + } +} diff --git a/src/multicore.rs b/src/multicore.rs new file mode 100644 index 00000000..d8323553 --- /dev/null +++ b/src/multicore.rs @@ -0,0 +1,16 @@ +pub use maybe_rayon::{ + iter::{IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, + join, scope, Scope, +}; + +#[cfg(feature = "multicore")] +pub use maybe_rayon::{ + current_num_threads, + iter::{IndexedParallelIterator, IntoParallelRefIterator}, + slice::ParallelSliceMut, +}; + +#[cfg(not(feature = "multicore"))] +pub fn current_num_threads() -> usize { + 1 
+} diff --git a/src/pasta/mod.rs b/src/pasta/mod.rs index 164697b5..078b663e 100644 --- a/src/pasta/mod.rs +++ b/src/pasta/mod.rs @@ -38,6 +38,9 @@ const ENDO_PARAMS_EP: EndoParameters = EndoParameters { endo!(Eq, Fp, ENDO_PARAMS_EQ); endo!(Ep, Fq, ENDO_PARAMS_EP); +prime_field_legendre!(Fp); +prime_field_legendre!(Fq); + #[test] fn test_endo() { use ff::Field; @@ -71,3 +74,9 @@ fn test_endo() { } } } + +#[test] +fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + crate::tests::field::random_quadratic_residue_test::(); +} diff --git a/src/secp256k1/fp.rs b/src/secp256k1/fp.rs index fdbfd16f..db496559 100644 --- a/src/secp256k1/fp.rs +++ b/src/secp256k1/fp.rs @@ -11,9 +11,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_p$ where /// /// `p = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f` @@ -23,9 +20,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fp` values are always in // Montgomery form; i.e., Fp(a) = aR mod p, with R = 2^256. 
#[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fp(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fp); + /// Constant representing the modulus /// p = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f const MODULUS: Fp = Fp([ @@ -288,6 +287,8 @@ impl WithSmallOrderMulGroup<3> for Fp { const ZETA: Self = ZETA; } +prime_field_legendre!(Fp); + #[cfg(test)] mod test { use super::*; @@ -360,4 +361,9 @@ mod test { #[cfg(feature = "derive_serde")] crate::tests::field::random_serde_test::("secp256k1 base".to_string()); } + + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/secp256k1/fq.rs b/src/secp256k1/fq.rs index 86ffcd9e..c6cb8e06 100644 --- a/src/secp256k1/fq.rs +++ b/src/secp256k1/fq.rs @@ -11,9 +11,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_q$ where /// /// `q = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141` @@ -23,9 +20,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fq` values are always in // Montgomery form; i.e., Fq(a) = aR mod q, with R = 2^256. 
#[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fq(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fq); + /// Constant representing the modulus /// q = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141 const MODULUS: Fq = Fq([ @@ -295,6 +294,8 @@ impl WithSmallOrderMulGroup<3> for Fq { const ZETA: Self = ZETA; } +prime_field_legendre!(Fq); + #[cfg(test)] mod test { use super::*; @@ -367,4 +368,8 @@ mod test { #[cfg(feature = "derive_serde")] crate::tests::field::random_serde_test::("secp256k1 scalar".to_string()); } + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/secp256r1/curve.rs b/src/secp256r1/curve.rs index 8a21f664..c23aa782 100644 --- a/src/secp256r1/curve.rs +++ b/src/secp256r1/curve.rs @@ -1,6 +1,7 @@ use crate::ff::WithSmallOrderMulGroup; use crate::ff::{Field, PrimeField}; use crate::group::{prime::PrimeCurveAffine, Curve, Group as _, GroupEncoding}; +use crate::hash_to_curve::simple_svdw_hash_to_curve; use crate::secp256r1::Fp; use crate::secp256r1::Fq; use crate::{Coordinates, CurveAffine, CurveAffineExt, CurveExt}; @@ -75,77 +76,98 @@ new_curve_impl!( SECP_A, SECP_B, "secp256r1", - |_, _| unimplemented!(), + |curve_id, domain_prefix| simple_svdw_hash_to_curve(curve_id, domain_prefix, Secp256r1::SSVDW_Z), ); -#[test] -fn test_curve() { - crate::tests::curve::curve_tests::(); +impl Secp256r1 { + // Optimal Z with: + // 0xffffffff00000001000000000000000000000000fffffffffffffffffffffff5 + // Z = -10 (reference: ) + const SSVDW_Z: Fp = Fp::from_raw([ + 0xfffffffffffffff5, + 0x00000000ffffffff, + 0x0000000000000000, + 0xffffffff00000001, + ]); } -#[test] -fn test_serialization() { - crate::tests::curve::random_serialization_test::(); - #[cfg(feature = "derive_serde")] - crate::tests::curve::random_serde_test::(); -} - -#[test] -fn 
ecdsa_example() { +#[cfg(test)] +mod tests { + use super::*; use crate::group::Curve; - use crate::CurveAffine; + use crate::secp256r1::{Fp, Fq, Secp256r1}; use ff::FromUniformBytes; use rand_core::OsRng; - fn mod_n(x: Fp) -> Fq { - let mut x_repr = [0u8; 32]; - x_repr.copy_from_slice(x.to_repr().as_ref()); - let mut x_bytes = [0u8; 64]; - x_bytes[..32].copy_from_slice(&x_repr[..]); - Fq::from_uniform_bytes(&x_bytes) + #[test] + fn test_hash_to_curve() { + crate::tests::curve::hash_to_curve_test::(); + } + + #[test] + fn test_curve() { + crate::tests::curve::curve_tests::(); } - let g = Secp256r1::generator(); + #[test] + fn test_serialization() { + crate::tests::curve::random_serialization_test::(); + #[cfg(feature = "derive_serde")] + crate::tests::curve::random_serde_test::(); + } + + #[test] + fn ecdsa_example() { + fn mod_n(x: Fp) -> Fq { + let mut x_repr = [0u8; 32]; + x_repr.copy_from_slice(x.to_repr().as_ref()); + let mut x_bytes = [0u8; 64]; + x_bytes[..32].copy_from_slice(&x_repr[..]); + Fq::from_uniform_bytes(&x_bytes) + } + + let g = Secp256r1::generator(); - for _ in 0..1000 { - // Generate a key pair - let sk = Fq::random(OsRng); - let pk = (g * sk).to_affine(); + for _ in 0..1000 { + // Generate a key pair + let sk = Fq::random(OsRng); + let pk = (g * sk).to_affine(); - // Generate a valid signature - // Suppose `m_hash` is the message hash - let msg_hash = Fq::random(OsRng); + // Generate a valid signature + // Suppose `m_hash` is the message hash + let msg_hash = Fq::random(OsRng); - let (r, s) = { - // Draw arandomness - let k = Fq::random(OsRng); - let k_inv = k.invert().unwrap(); + let (r, s) = { + // Draw arandomness + let k = Fq::random(OsRng); + let k_inv = k.invert().unwrap(); - // Calculate `r` - let r_point = (g * k).to_affine().coordinates().unwrap(); - let x = r_point.x(); - let r = mod_n(*x); + // Calculate `r` + let r_point = (g * k).to_affine().coordinates().unwrap(); + let x = r_point.x(); + let r = mod_n(*x); - // Calculate `s` - 
let s = k_inv * (msg_hash + (r * sk)); + // Calculate `s` + let s = k_inv * (msg_hash + (r * sk)); - (r, s) - }; + (r, s) + }; - { - // Verify - let s_inv = s.invert().unwrap(); - let u_1 = msg_hash * s_inv; - let u_2 = r * s_inv; + { + // Verify + let s_inv = s.invert().unwrap(); + let u_1 = msg_hash * s_inv; + let u_2 = r * s_inv; - let v_1 = g * u_1; - let v_2 = pk * u_2; + let v_1 = g * u_1; + let v_2 = pk * u_2; - let r_point = (v_1 + v_2).to_affine().coordinates().unwrap(); - let x_candidate = r_point.x(); - let r_candidate = mod_n(*x_candidate); + let r_point = (v_1 + v_2).to_affine().coordinates().unwrap(); + let x_candidate = r_point.x(); + let r_candidate = mod_n(*x_candidate); - assert_eq!(r, r_candidate); + assert_eq!(r, r_candidate); + } } } } diff --git a/src/secp256r1/fp.rs b/src/secp256r1/fp.rs index 61a85be9..331545c3 100644 --- a/src/secp256r1/fp.rs +++ b/src/secp256r1/fp.rs @@ -11,9 +11,6 @@ use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_p$ where /// /// `p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff @@ -23,9 +20,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fp` values are always in // Montgomery form; i.e., Fp(a) = aR mod p, with R = 2^256. 
#[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fp(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fp); + /// Constant representing the modulus /// p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff const MODULUS: Fp = Fp([ @@ -306,6 +305,8 @@ impl WithSmallOrderMulGroup<3> for Fp { const ZETA: Self = ZETA; } +prime_field_legendre!(Fp); + #[cfg(test)] mod test { use super::*; @@ -378,4 +379,9 @@ mod test { #[cfg(feature = "derive_serde")] crate::tests::field::random_serde_test::("secp256r1 base".to_string()); } + + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/secp256r1/fq.rs b/src/secp256r1/fq.rs index ef34d952..077ec331 100644 --- a/src/secp256r1/fq.rs +++ b/src/secp256r1/fq.rs @@ -1,14 +1,10 @@ use crate::arithmetic::{adc, mac, macx, sbb}; use crate::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup}; -use core::convert::TryInto; use core::fmt; use core::ops::{Add, Mul, Neg, Sub}; use rand::RngCore; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; -#[cfg(feature = "derive_serde")] -use serde::{Deserialize, Serialize}; - /// This represents an element of $\mathbb{F}_q$ where /// /// `q = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551` @@ -18,9 +14,11 @@ use serde::{Deserialize, Serialize}; // integers in little-endian order. `Fq` values are always in // Montgomery form; i.e., Fq(a) = aR mod q, with R = 2^256. 
#[derive(Clone, Copy, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "derive_serde", derive(Serialize, Deserialize))] pub struct Fq(pub(crate) [u64; 4]); +#[cfg(feature = "derive_serde")] +crate::serialize_deserialize_32_byte_primefield!(Fq); + /// Constant representing the modulus /// q = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551 const MODULUS: Fq = Fq([ @@ -291,6 +289,8 @@ impl WithSmallOrderMulGroup<3> for Fq { const ZETA: Self = ZETA; } +prime_field_legendre!(Fq); + #[cfg(test)] mod test { use super::*; @@ -363,4 +363,9 @@ mod test { #[cfg(feature = "derive_serde")] crate::tests::field::random_serde_test::("secp256r1 scalar".to_string()); } + + #[test] + fn test_quadratic_residue() { + crate::tests::field::random_quadratic_residue_test::(); + } } diff --git a/src/tests/curve.rs b/src/tests/curve.rs index 9bb0fe0b..2f93bbb4 100644 --- a/src/tests/curve.rs +++ b/src/tests/curve.rs @@ -2,6 +2,7 @@ use crate::ff::Field; use crate::group::prime::PrimeCurveAffine; +use crate::legendre::Legendre; use crate::tests::fe_from_str; use crate::{group::GroupEncoding, serde::SerdeObject}; use crate::{hash_to_curve, CurveAffine, CurveExt}; @@ -73,12 +74,24 @@ where assert_eq!(projective_point.to_affine(), affine_point_rec); assert_eq!(affine_point, affine_point_rec); } + { + let affine_json = serde_json::to_string(&affine_point).unwrap(); + let reader = std::io::Cursor::new(affine_json); + let affine_point_rec: G::AffineExt = serde_json::from_reader(reader).unwrap(); + assert_eq!(affine_point, affine_point_rec); + } { let projective_bytes = bincode::serialize(&projective_point).unwrap(); let reader = std::io::Cursor::new(projective_bytes); let projective_point_rec: G = bincode::deserialize_from(reader).unwrap(); assert_eq!(projective_point, projective_point_rec); } + { + let projective_json = serde_json::to_string(&projective_point).unwrap(); + let reader = std::io::Cursor::new(projective_json); + let projective_point_rec: G = 
serde_json::from_reader(reader).unwrap(); + assert_eq!(projective_point, projective_point_rec); + } } } @@ -343,7 +356,9 @@ pub fn svdw_map_to_curve_test( z: G::Base, precomputed_constants: [&'static str; 4], test_vector: impl IntoIterator, -) { +) where + ::Base: Legendre, +{ let [c1, c2, c3, c4] = hash_to_curve::svdw_precomputed_constants::(z); assert_eq!([c1, c2, c3, c4], precomputed_constants.map(fe_from_str)); for (u, (x, y)) in test_vector.into_iter() { diff --git a/src/tests/field.rs b/src/tests/field.rs index a064441e..02f5509f 100644 --- a/src/tests/field.rs +++ b/src/tests/field.rs @@ -1,6 +1,7 @@ -use crate::ff::Field; use crate::serde::SerdeObject; +use crate::{ff::Field, legendre::Legendre}; use ark_std::{end_timer, start_timer}; +use ff::PrimeField; use rand::{RngCore, SeedableRng}; use rand_xorshift::XorShiftRng; @@ -279,11 +280,31 @@ where let _message = format!("serialization with serde {type_name}"); let start = start_timer!(|| _message); for _ in 0..1000000 { + // byte serialization let a = F::random(&mut rng); let bytes = bincode::serialize(&a).unwrap(); let reader = std::io::Cursor::new(bytes); let b: F = bincode::deserialize_from(reader).unwrap(); assert_eq!(a, b); + + // json serialization + let json = serde_json::to_string(&a).unwrap(); + let reader = std::io::Cursor::new(json); + let b: F = serde_json::from_reader(reader).unwrap(); + assert_eq!(a, b); } end_timer!(start); } + +pub fn random_quadratic_residue_test() { + let mut rng = XorShiftRng::from_seed([ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, + 0xe5, + ]); + for _ in 0..100000 { + let elem = F::random(&mut rng); + let is_quad_res_or_zero: bool = elem.sqrt().is_some().into(); + let is_quad_non_res: bool = elem.ct_quadratic_non_residue().into(); + assert_eq!(!is_quad_non_res, is_quad_res_or_zero) + } +}