diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bf75df26..f28e8181 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,6 +10,7 @@ on: # Make sure CI fails on all warnings, including Clippy lints env: RUSTFLAGS: "-Dwarnings" + RUSTDOCFLAGS: "-Dwarnings" jobs: @@ -17,25 +18,33 @@ jobs: name: Rustfmt runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@nightly with: - profile: minimal - toolchain: nightly - override: true components: rustfmt - - uses: actions-rs/cargo@v1 - with: - command: fmt - args: --all -- --check + - run: cargo fmt --all -- --check clippy_check: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Clippy - run: cargo clippy --all-targets --all-features + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: clippy + - uses: Swatinem/rust-cache@v2 + - run: cargo clippy --all-targets --all-features --tests + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@1.60.0 + - uses: Swatinem/rust-cache@v2 + - run: | + cargo update -p proptest --precise "1.2.0" + cargo update -p tempfile --precise "3.3.0" + - run: cargo tree --all-features # to debug deps issues + - run: cargo build --release --all-features # We want to test stable on multiple platforms with --all-features test: @@ -45,11 +54,11 @@ jobs: matrix: target: ["x86_64-unknown-linux-gnu", "armv7-unknown-linux-gnueabihf"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: 1.60.0 + toolchain: stable target: ${{ matrix.target }} override: true - uses: Swatinem/rust-cache@v2.0.0 @@ -65,18 +74,10 @@ jobs: test-nightly: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: nightly - target: "x86_64-unknown-linux-gnu" - override: true - - uses: Swatinem/rust-cache@v2.0.0 - - uses: actions-rs/cargo@v1 - with: - command: test - args: --release --all-features + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@nightly + - uses: Swatinem/rust-cache@v2 + - run: cargo test --release --all-features # test without default features test-minimal: @@ -85,18 +86,10 @@ jobs: matrix: package: [ "secp256kfun", "sigma_fun", "ecdsa_fun", "schnorr_fun" ] steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - target: "x86_64-unknown-linux-gnu" - override: true + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2.0.0 - - uses: actions-rs/cargo@v1 - with: - command: test - args: --release --no-default-features -p ${{ matrix.package }} + - run: cargo test --release --no-default-features -p ${{ matrix.package }} # test with alloc feature only @@ -106,30 +99,16 @@ jobs: matrix: package: [ "secp256kfun", "sigma_fun", "ecdsa_fun", "schnorr_fun" ] steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - target: "x86_64-unknown-linux-gnu" - override: true + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2.0.0 - - uses: actions-rs/cargo@v1 - with: - command: test - args: --release --no-default-features --features alloc -p ${{ matrix.package }} + - run: cargo test --release --no-default-features --features alloc -p ${{ matrix.package }} doc-build: name: doc-build runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: nightly - override: true - - name: build-doc - # convoluted way to make it fail on warnings - run: "cargo doc --no-deps --workspace 2>&1 | tee /dev/fd/2 | grep -iEq '^(warning|error)' && exit 1 || exit 0" + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - run: cargo doc --no-deps --workspace diff --git a/CHANGELOG.md b/CHANGELOG.md index dad2b75a..c4053ac5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ # CHANGELOG +## Unreleased + +- Added `arithmetic_macros` to make `g!` and `s!` macros into procedural macros +- Made even `Secret` things `Copy`. See discussion [here](https://github.com/LLFourn/secp256kfun/issues/6#issuecomment-1363752651). + ## v0.9.1 - Added more `bincode` derives for FROST things diff --git a/Cargo.toml b/Cargo.toml index 7572966a..1a3bdb11 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "secp256kfun", "schnorr_fun", "ecdsa_fun", - "sigma_fun" + "sigma_fun", + "arithmetic_macros" ] resolver = "2" diff --git a/arithmetic_macros/Cargo.toml b/arithmetic_macros/Cargo.toml new file mode 100644 index 00000000..2ca5baf7 --- /dev/null +++ b/arithmetic_macros/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "secp256kfun_arithmetic_macros" +version = "0.9.0" +edition = "2021" + +[lib] +proc-macro = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +proc-macro2 = "1" +quote = "1" diff --git a/arithmetic_macros/src/lib.rs b/arithmetic_macros/src/lib.rs new file mode 100644 index 00000000..0114f94a --- /dev/null +++ b/arithmetic_macros/src/lib.rs @@ -0,0 +1,168 @@ +mod optree; +use optree::{Infix, InfixKind, Node, OpTree}; +use proc_macro::TokenStream; +use proc_macro2::{Ident, TokenTree}; +use quote::{quote, quote_spanned}; +use std::iter::Peekable; +type Input = Peekable; + +#[proc_macro] +pub fn gen_s(input: TokenStream) -> TokenStream { + let input: proc_macro2::TokenStream = input.into(); + let mut iter = input.into_iter().peekable(); + + let path = match iter.next() { + Some(TokenTree::Ident(path)) => path, + _ => panic!("put the path to secpfun crate first"), + }; + let optree = match optree::parse_tokens(&mut iter) { + Ok(optree) => optree, + Err(e) => { + let problem = e.problem; + return quote_spanned!(e.span => compile_error!(#problem)).into(); + } + }; + + compile_s(&path, optree).into() +} + +fn compile_s(path: &Ident, node: Node) -> proc_macro2::TokenStream { + match *node.tree { + OpTree::Infix(Infix { lhs, rhs, kind }) => { + let lhs_ = compile_s(path, lhs); + let mut rhs_ = compile_s(path, rhs); + let fn_name = Ident::new( + match kind { + InfixKind::Add => "scalar_add", + InfixKind::Mul => "scalar_mul", + InfixKind::Sub => "scalar_sub", + InfixKind::LinComb => "scalar_dot_product", + InfixKind::Div => { + rhs_ = quote_spanned! { node.span => #path::op::scalar_invert(#rhs_) }; + "scalar_mul" + } + }, + node.span, + ); + + quote_spanned! { node.span => #path::op::#fn_name(#lhs_, #rhs_) } + } + OpTree::Unary(unary) => match unary.kind { + optree::UnaryKind::Neg => { + let fn_name = Ident::new("scalar_negate", node.span); + let subj = compile_s(path, unary.subj); + quote_spanned! { node.span => #path::op::#fn_name(#subj) } + } + optree::UnaryKind::Ref => { + let a = unary.punct; + let subj = compile_g(path, unary.subj); + quote!( #a #subj ) + } + }, + OpTree::Term(ts) => ts, + OpTree::Paren(node) => compile_s(path, node), + OpTree::LitInt(lit_int) => { + if lit_int == 0 { + quote_spanned! { node.span => #path::Scalar::<#path::marker::Secret, _>::zero() } + } else { + quote_spanned! { node.span => + #path::Scalar::<#path::marker::Secret, #path::marker::NonZero>::from_non_zero_u32(unsafe { + core::num::NonZeroU32::new_unchecked(#lit_int) + }) + } + } + } + } +} + +#[proc_macro] +pub fn gen_g(input: TokenStream) -> TokenStream { + let input: proc_macro2::TokenStream = input.into(); + let mut iter = input.into_iter().peekable(); + + let path = match iter.next() { + Some(TokenTree::Ident(path)) => path, + _ => panic!("put the path to secpfun crate first"), + }; + let node = match optree::parse_tokens(&mut iter) { + Ok(optree) => optree, + Err(e) => { + let problem = e.problem; + return quote_spanned!(e.span => compile_error!(#problem)).into(); + } + }; + + compile_g(&path, node).into() +} + +fn compile_g(path: &Ident, node: Node) -> proc_macro2::TokenStream { + match *node.tree { + OpTree::Infix(Infix { lhs, rhs, kind }) => match kind { + InfixKind::Add | InfixKind::Sub => { + let is_sub = kind == InfixKind::Sub; + match (&*lhs.tree, &*rhs.tree) { + ( + OpTree::Infix(Infix { + kind: InfixKind::Mul, + lhs: llhs, + rhs: lrhs, + }), + OpTree::Infix(Infix { + kind: InfixKind::Mul, + lhs: rlhs, + rhs: rrhs, + }), + ) => { + let llhs = compile_s(path, llhs.clone()); + let lrhs = compile_g(path, lrhs.clone()); + let mut rlhs = compile_s(path, rlhs.clone()); + let rrhs = compile_g(path, rrhs.clone()); + if is_sub { + rlhs = quote_spanned! { node.span => #path::op::scalar_negate(#rlhs) }; + } + quote_spanned! { node.span => #path::op::double_mul(#llhs, #lrhs, #rlhs, #rrhs) } + } + (..) => { + let lhs_ = compile_g(path, lhs); + let rhs_ = compile_g(path, rhs); + if is_sub { + quote_spanned! { node.span => #path::op::point_sub(#lhs_, #rhs_) } + } else { + quote_spanned! { node.span => #path::op::point_add(#lhs_, #rhs_) } + } + } + } + } + InfixKind::Mul => { + let lhs_ = compile_s(path, lhs); + let rhs_ = compile_g(path, rhs); + quote_spanned! { node.span => #path::op::scalar_mul_point(#lhs_, #rhs_) } + } + InfixKind::LinComb => { + let lhs_ = compile_s(path, lhs); + let rhs_ = compile_g(path, rhs); + quote_spanned! { node.span => #path::op::point_scalar_dot_product(#lhs_, #rhs_) } + } + InfixKind::Div => { + quote_spanned! { node.span => compile_error!("can't use division in group expression") } + } + }, + OpTree::Term(term) => term, + OpTree::Paren(node) => compile_g(path, node), + OpTree::Unary(unary) => match unary.kind { + optree::UnaryKind::Neg => { + let fn_name = Ident::new("point_negate", node.span); + let subj = compile_g(path, unary.subj); + quote_spanned! { node.span => #path::op::#fn_name(#subj) } + } + optree::UnaryKind::Ref => { + let a = unary.punct; + let subj = compile_g(path, unary.subj); + quote!( #a #subj ) + } + }, + OpTree::LitInt(lit_int) => { + quote_spanned! { node.span => compile_error!("can't use literal int {} in group expression", #lit_int)} + } + } +} diff --git a/arithmetic_macros/src/optree.rs b/arithmetic_macros/src/optree.rs new file mode 100644 index 00000000..e350777e --- /dev/null +++ b/arithmetic_macros/src/optree.rs @@ -0,0 +1,571 @@ +#![allow(unused)] +use super::Input; +use proc_macro2::{token_stream, Delimiter, Punct, Span, TokenStream, TokenTree}; +use quote::{quote_spanned, ToTokens}; +use std::{fmt::Display, iter::Peekable}; + +#[derive(Clone)] +pub(crate) enum OpTree { + Infix(Infix), + Term(TokenStream), + Paren(Node), + Unary(Unary), + LitInt(u32), +} + +#[derive(Clone)] +pub(crate) struct Node { + pub tree: Box, + pub span: Span, +} + +impl core::fmt::Debug for Node { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.tree.fmt(f) + } +} + +impl core::fmt::Debug for OpTree { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Infix(infix) => f + .debug_tuple(&infix.kind.to_string()) + .field(&infix.lhs) + .field(&infix.rhs) + .finish(), + Self::Term(arg0) => write!(f, "{}", arg0.to_string().replace(' ', "")), + Self::Paren(arg0) => arg0.fmt(f), + Self::Unary(unary) => f + .debug_tuple(&unary.kind.to_string()) + .field(&unary.subj) + .finish(), + Self::LitInt(arg0) => write!(f, "{}", arg0), + } + } +} + +impl Node { + fn new(tree: OpTree, span: Span) -> Self { + Node { + tree: Box::new(tree), + span, + } + } +} + +#[derive(Clone, Debug)] +pub(crate) struct Infix { + pub lhs: Node, + pub rhs: Node, + pub kind: InfixKind, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +pub(crate) enum InfixKind { + Add, + Mul, + Sub, + LinComb, + Div, +} + +impl InfixKind { + fn precedence(self) -> u8 { + match self { + InfixKind::Add | InfixKind::Sub => 0, + InfixKind::Mul | InfixKind::LinComb | InfixKind::Div => 1, + } + } +} + +impl core::fmt::Display for InfixKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", match self { + InfixKind::Add => "+", + InfixKind::Mul => "*", + InfixKind::Sub => "-", + InfixKind::LinComb => ".*", + InfixKind::Div => "/", + }) + } +} + +#[derive(Clone, Debug)] +pub(crate) struct Unary { + pub subj: Node, + pub kind: UnaryKind, + pub punct: Punct, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum UnaryKind { + Neg, + Ref, +} + +impl core::fmt::Display for UnaryKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", match self { + UnaryKind::Neg => "-", + UnaryKind::Ref => "&", + }) + } +} + +#[derive(Clone, Debug)] +pub struct Error { + pub span: Span, + pub problem: String, +} + +pub(crate) fn token_stream_to_node(ts: token_stream::TokenStream) -> Result { + parse_tokens(&mut ts.into_iter().peekable()) +} + +pub(crate) fn parse_tokens(input: &mut Input) -> Result { + rule_opchain(input) +} + +fn rule_term(input: &mut Input) -> Result { + let unaries = rule_prefix(input)?; + let next = input + .peek() + .expect("must not be called with an empty input"); + let mut span = next.span(); + let mut optree = match next { + TokenTree::Ident(_) => { + let mut tt = TokenStream::new(); + tt.extend(input.next()); + tt.extend(rule_postfix(input)?); + OpTree::Term(tt) + } + TokenTree::Group(group) => { + let group = group.clone(); + match group.delimiter() { + Delimiter::Parenthesis => { + let _ = input.next(); + OpTree::Paren(token_stream_to_node(group.stream())?) + } + Delimiter::Brace => { + let input: TokenStream = input.next().unwrap().into(); + let term = quote_spanned! { span => #[allow(unused_braces)] #input }; + OpTree::Term(term) + } + _ => { + return Err(Error { + span: group.span(), + problem: "can only use '(..)' or '{..}'".into(), + }) + } + } + } + TokenTree::Literal(lit) => { + let int_lit: u32 = lit.to_string().parse().map_err(|e| Error { + span: lit.span(), + problem: "only u32 literals are supported".into(), + })?; + let _ = input.next(); + OpTree::LitInt(int_lit) + } + tt => { + return Err(Error { + span: tt.span(), + problem: "this is an invalid term".into(), + }) + } + }; + + for (unary_kind, punct) in unaries { + optree = OpTree::Unary(Unary { + kind: unary_kind, + subj: Node::new(optree, span), + punct: punct.clone(), + }); + span = punct.span(); + } + Ok(Node::new(optree, span)) +} + +fn rule_prefix(input: &mut Input) -> Result, Error> { + let mut unaries = vec![]; + while let Some(TokenTree::Punct(punct)) = input.peek() { + match punct.as_char() { + '-' => { + unaries.push((UnaryKind::Neg, punct.to_owned())); + let _ = input.next(); + } + '&' => { + unaries.push((UnaryKind::Ref, punct.to_owned())); + let _ = input.next(); + } + _ => break, + } + } + Ok(unaries) +} +fn rule_postfix(input: &mut Input) -> Result, Error> { + let mut tokens = vec![]; + + loop { + let mut lookahead = input.clone(); + let next = lookahead.next(); + match next { + Some(TokenTree::Punct(punct)) => { + if punct.as_char() == '.' { + let is_dot_product = matches!(lookahead.peek(), Some(TokenTree::Punct(punct)) if punct.as_char() == '*'); + if is_dot_product { + break; + } + + tokens.push(input.next().unwrap()); + + let error = Err(Error { + span: punct.span(), + problem: + "expecting a method call, property access or tuple access after period" + .into(), + }); + // look for .0, .foo or .foo(a,b) + match input.next() { + Some(following_period) => match &following_period { + TokenTree::Ident(_) => { + tokens.push(following_period); + } + TokenTree::Literal(lit) if lit.to_string().parse::().is_ok() => { + tokens.push(following_period); + } + _following_period => return error, + }, + None => return error, + } + } else { + break; + } + } + Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Parenthesis => { + tokens.push(input.next().unwrap()); + } + Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Bracket => { + tokens.push(input.next().unwrap()); + } + _ => break, + } + } + + Ok(tokens) +} + +fn rule_opchain(input: &mut Input) -> Result { + let mut lhs = rule_term(input)?; + + if input.peek().is_none() { + return Ok(lhs); + } + + let (kind, span) = rule_infix_op(input)?; + + let mut rhs = rule_opchain(input)?; + let mut top_node = Node::new(OpTree::Infix(Infix { lhs, rhs, kind }), span); + let mut cursor = &mut top_node; + + while let OpTree::Infix(infix) = *cursor.tree.clone() { + match &*infix.rhs.tree { + OpTree::Infix(rhs_infix) if infix.kind.precedence() >= rhs_infix.kind.precedence() => { + let fixed = Node::new( + OpTree::Infix(Infix { + lhs: Node::new( + OpTree::Infix(Infix { + lhs: infix.lhs.clone(), + rhs: rhs_infix.lhs.clone(), + kind, + }), + span, + ), + rhs: rhs_infix.rhs.clone(), + kind: rhs_infix.kind, + }), + infix.rhs.span, + ); + *cursor = fixed.clone(); + cursor = match &mut *cursor.tree { + OpTree::Infix(infix) => &mut infix.lhs, + _ => unreachable!(), + } + } + _ => break, + } + } + + Ok(top_node) +} + +fn rule_infix_op(input: &mut Input) -> Result<(InfixKind, Span), Error> { + let next = input.next().expect("must not be called on empty input"); + match next { + TokenTree::Punct(punct) => { + let error = Err(Error { + span: punct.span(), + problem: "unknown infix operator".into(), + }); + + let op = match punct.as_char() { + '+' => InfixKind::Add, + '*' => InfixKind::Mul, + '-' => InfixKind::Sub, + '.' => match input.next() { + Some(TokenTree::Punct(star)) if star.as_char() == '*' => InfixKind::LinComb, + _ => return error, + }, + '/' => InfixKind::Div, + _ => return error, + }; + Ok((op, punct.span())) + } + _ => Err(Error { + span: next.span(), + problem: "expecting an infix operator".into(), + }), + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::str::FromStr; + + macro_rules! parse { + ($lit:expr) => { + *match token_stream_to_node(TokenStream::from_str($lit).unwrap()) { + Err(e) => panic!("{}", e.problem), + Ok(expr) => expr, + } + .tree + }; + } + + #[test] + fn test_term() { + assert!(matches!(parse!("a_term"), OpTree::Term(tt) if tt.to_string() == "a_term")); + } + + #[test] + fn add2() { + let ot = parse!("a + b"); + assert!( + matches!(ot, OpTree::Infix (Infix { lhs, rhs, kind: InfixKind::Add }) if + matches!(&*lhs.tree, OpTree::Term(tt) if tt.to_string() == "a") && + matches!(&*rhs.tree, OpTree::Term(tt) if tt.to_string() == "b") + ) + ); + } + + #[test] + fn add3() { + let ot = parse!("a + b + c"); + assert!( + matches!(ot, OpTree::Infix(Infix { lhs, rhs, kind: InfixKind::Add }) if + matches!(&*lhs.tree, OpTree::Infix(Infix{ lhs, rhs, kind: InfixKind::Add }) if + matches!(&*lhs.tree, OpTree::Term(a) if a.to_string() == "a") && + matches!(&*rhs.tree, OpTree::Term(b) if b.to_string() == "b") + ) && + matches!(&*rhs.tree, OpTree::Term(c) if c.to_string() == "c") + ) + ); + } + + #[test] + fn add_mul3() { + let ot = parse!("a * A + b * B + c * C"); + dbg!(&ot); + assert!( + matches!(ot, OpTree::Infix(Infix { lhs, rhs, kind: InfixKind::Add }) if + matches!(&*lhs.tree, OpTree::Infix(Infix { lhs, rhs, kind: InfixKind::Add }) + if matches!(&*lhs.tree, OpTree::Infix(Infix { lhs, rhs, kind: InfixKind::Mul })) && + matches!(&*rhs.tree, OpTree::Infix(Infix { lhs, rhs, kind: InfixKind::Mul })) + ) && + matches!(&*rhs.tree, OpTree::Infix(Infix { lhs, rhs, kind: InfixKind::Mul }))) + ); + } + + #[test] + fn addparen() { + let ot = parse!("(a + b) + c"); + assert!( + matches!(ot, OpTree::Infix(Infix { lhs, kind: InfixKind::Add, .. }) if + matches!(&*lhs.tree, OpTree::Paren(paren) if + matches!(&*paren.tree, OpTree::Infix(Infix { kind: InfixKind::Add, .. })) + )) + ); + } + + #[test] + fn addparen2() { + let ot = parse!("a + (b + c)"); + assert!( + matches!(ot, OpTree::Infix(Infix { rhs, kind: InfixKind::Add, .. }) if + matches!(&*rhs.tree, OpTree::Paren(paren) if + matches!(&*paren.tree, OpTree::Infix(Infix { kind: InfixKind::Add, .. })) + )) + ); + } + + #[test] + fn addmul() { + let ot = parse!("a + b * c"); + assert!( + matches!(ot, OpTree::Infix(Infix { lhs, rhs, kind: InfixKind::Add }) if + matches!(&*lhs.tree, OpTree::Term(a) if a.to_string() == "a") + && + matches!(&*rhs.tree, OpTree::Infix(Infix{ lhs, rhs, kind: InfixKind::Mul }) if + matches!(&*lhs.tree, OpTree::Term(b) if b.to_string() == "b") && + matches!(&*rhs.tree, OpTree::Term(c) if c.to_string() == "c") + ) + ) + ); + } + + #[test] + fn muladd() { + let ot = parse!("a * b + c"); + assert!( + matches!(ot, OpTree::Infix(Infix { lhs, rhs, kind: InfixKind::Add }) if + matches!(&*lhs.tree, OpTree::Infix(Infix{ lhs, rhs, kind: InfixKind::Mul }) if + matches!(&*lhs.tree, OpTree::Term(a) if a.to_string() == "a") && + matches!(&*rhs.tree, OpTree::Term(b) if b.to_string() == "b") + ) && + matches!(&*rhs.tree, OpTree::Term(c) if c.to_string() == "c") + ) + ); + } + + #[test] + fn addsub() { + let ot = parse!("a + b - c"); + assert!( + matches!(ot, OpTree::Infix(Infix { lhs, rhs, kind: InfixKind::Sub }) if + matches!(&*lhs.tree, OpTree::Infix(Infix{ lhs, rhs, kind: InfixKind::Add }) if + matches!(&*lhs.tree, OpTree::Term(a) if a.to_string() == "a") && + matches!(&*rhs.tree, OpTree::Term(b) if b.to_string() == "b") + ) && + matches!(&*rhs.tree, OpTree::Term(c) if c.to_string() == "c") + ) + ); + } + + #[test] + fn subadd() { + let ot = parse!("a - b + c"); + assert!( + matches!(ot, OpTree::Infix(Infix { lhs, rhs, kind: InfixKind::Add }) if + matches!(&*lhs.tree, OpTree::Infix(Infix{ lhs, rhs, kind: InfixKind::Sub }) if + matches!(&*lhs.tree, OpTree::Term(a) if a.to_string() == "a") && + matches!(&*rhs.tree, OpTree::Term(b) if b.to_string() == "b") + ) && + matches!(&*rhs.tree, OpTree::Term(c) if c.to_string() == "c") + ) + ); + } + + #[test] + fn unary_negate() { + let ot = parse!("-a"); + assert!( + matches!(ot, OpTree::Unary(Unary { kind: UnaryKind::Neg, subj, .. }) if matches!(&*subj.tree, OpTree::Term(a) if a.to_string() == "a")) + ) + } + + #[test] + fn unary_ref() { + let ot = parse!("&a"); + assert!( + matches!(ot, OpTree::Unary(Unary { kind: UnaryKind::Ref, subj, .. }) if matches!(&*subj.tree, OpTree::Term(a) if a.to_string() == "a")) + ) + } + + #[test] + fn double_negate() { + let ot = parse!("--a"); + assert!( + matches!(ot, OpTree::Unary(Unary { kind: UnaryKind::Neg, subj, ..}) + if matches!(&*subj.tree, OpTree::Unary(Unary { kind: UnaryKind::Neg, subj, .. }) + if matches!(&*subj.tree, OpTree::Term(a) if a.to_string() == "a"))) + ) + } + + #[test] + fn unary_negate_with_infix() { + let ot = parse!("-a * b"); + assert!(matches!(ot, OpTree::Infix( Infix { lhs, .. },) + if matches!(&*lhs.tree, OpTree::Unary(Unary { kind: UnaryKind::Neg, subj, .. }) + if matches!(&*subj.tree, OpTree::Term(a) if a.to_string() == "a")))) + } + + #[test] + fn dot_product() { + let ot = parse!("a .* b"); + assert!( + matches!(ot, OpTree::Infix( Infix { lhs, rhs, kind: InfixKind::LinComb } ) if + matches!(&*lhs.tree, OpTree::Term(a) if a.to_string() == "a") && + matches!(&*rhs.tree, OpTree::Term(b) if b.to_string() == "b") + ) + ) + } + + #[test] + fn callmethod() { + let ot = parse!("term.method(other, stuff)"); + assert!( + matches!(ot, OpTree::Term(call) if call.to_string() == "term . method (other , stuff)") + ); + let ot = parse!("term.method(other, stuff).another()"); + assert!( + matches!(ot, OpTree::Term(call) if call.to_string() == "term . method (other , stuff) . another ()") + ); + } + + #[test] + fn property() { + let ot = parse!("term.property"); + assert!(matches!(ot, OpTree::Term(call) if call.to_string() == "term . property")); + let ot = parse!("term.property.another"); + assert!( + matches!(ot, OpTree::Term(call) if call.to_string() == "term . property . another") + ); + } + + #[test] + fn tuple_index() { + let ot = parse!("term.0"); + assert!(matches!(ot, OpTree::Term(call) if call.to_string().replace(' ', "") == "term.0")); + let ot = parse!("term.1.2.3"); + assert!( + matches!(ot, OpTree::Term(call) if call.to_string().replace(' ', "") == "term.1.2.3") + ); + } + + #[test] + fn array_index() { + let ot = parse!("term[1]"); + assert!(matches!(ot, OpTree::Term(call) if call.to_string().replace(' ', "") == "term[1]")); + let ot = parse!("term[1..10]"); + assert!( + matches!(ot, OpTree::Term(call) if call.to_string().replace(' ', "") == "term[1..10]") + ); + } + + #[test] + fn lots_of_junk_added() { + let ot = parse!("term(arg1, arg2)[1].7.a_method()[2] + what.a.long.1[6].tail(of, things)"); + assert!( + matches!(ot, OpTree::Infix( Infix { lhs, rhs, kind: InfixKind::Add }) if + matches!(&*lhs.tree, OpTree::Term(call) if call.to_string().replace(' ', "") == "term(arg1,arg2)[1].7.a_method()[2]") && + matches!(&*rhs.tree, OpTree::Term(call) if call.to_string().replace(' ', "") == "what.a.long.1[6].tail(of,things)")) + ); + } + + #[test] + fn int_lit() { + let ot = parse!("1"); + assert!(matches!(ot, OpTree::LitInt(1u32))); + } +} diff --git a/ecdsa_fun/src/adaptor/mod.rs b/ecdsa_fun/src/adaptor/mod.rs index 4b10c606..91c6bd28 100644 --- a/ecdsa_fun/src/adaptor/mod.rs +++ b/ecdsa_fun/src/adaptor/mod.rs @@ -166,7 +166,7 @@ impl, NG> Adaptor { // will also be uniform. .expect("computationally unreachable"); - let s_hat = s!({ r.invert() } * (m + R_x * x)) + let s_hat = s!(r.invert() * (m + R_x * x)) .public() .non_zero() .expect("computationally unreachable"); @@ -286,7 +286,7 @@ impl, NG> Adaptor { return None; } let s = &signature.s; - let y = s!({ s.invert() } * s_hat); + let y = s!(s.invert() * s_hat); let Y = g!(y * G); if Y == *encryption_key { diff --git a/ecdsa_fun/src/lib.rs b/ecdsa_fun/src/lib.rs index a611d3cd..2129ae1c 100755 --- a/ecdsa_fun/src/lib.rs +++ b/ecdsa_fun/src/lib.rs @@ -177,17 +177,18 @@ impl ECDSA { .non_zero() .expect("computationally unreachable"); - let mut s = s!({ r.invert() } * (m + R_x * x)) + let mut s = s!((m + R_x * x) / r) // Given R_x is determined by x and m through a hash, reaching // (m + R_x * x) = 0 is intractable. .non_zero() - .expect("computationally unreachable"); + .expect("computationally unreachable") + .public(); // s values must be low (less than half group order), otherwise signatures // would be malleable i.e. (R,s) and (R,-s) would both be valid signatures. s.conditional_negate(s.is_high()); - Signature { R_x, s: s.public() } + Signature { R_x, s } } } diff --git a/ecdsa_fun/tests/adaptor_test_vectors.rs b/ecdsa_fun/tests/adaptor_test_vectors.rs index 2d06d5a9..de42477f 100644 --- a/ecdsa_fun/tests/adaptor_test_vectors.rs +++ b/ecdsa_fun/tests/adaptor_test_vectors.rs @@ -102,5 +102,5 @@ fn run_test_vector( let decryption_key = ecdsa_adaptor.recover_decryption_key(&t.encryption_key, &t.signature, &t.adaptor_sig); - decryption_key == Some(t.decryption_key.clone()) + decryption_key == Some(t.decryption_key) } diff --git a/schnorr_fun/benches/bench_schnorr.rs b/schnorr_fun/benches/bench_schnorr.rs index efb5a231..ddc9e4d1 100755 --- a/schnorr_fun/benches/bench_schnorr.rs +++ b/schnorr_fun/benches/bench_schnorr.rs @@ -18,7 +18,7 @@ lazy_static::lazy_static! { fn sign_schnorr(c: &mut Criterion) { let mut group = c.benchmark_group("schnorr_sign"); { - let keypair = schnorr.new_keypair(SK.clone()); + let keypair = schnorr.new_keypair(*SK); group.bench_function("fun::schnorr_sign", |b| { b.iter(|| schnorr.sign(&keypair, Message::::raw(MESSAGE))) }); @@ -27,7 +27,7 @@ fn sign_schnorr(c: &mut Criterion) { { use secp256k1::{KeyPair, Message, Secp256k1}; let secp = Secp256k1::new(); - let kp = KeyPair::from_secret_key(&secp, &SK.clone().into()); + let kp = KeyPair::from_secret_key(&secp, &(*SK).into()); let msg = Message::from_slice(&MESSAGE[..]).unwrap(); group.bench_function("secp::schnorrsig_sign_no_aux_rand", |b| { b.iter(|| { @@ -39,7 +39,7 @@ fn sign_schnorr(c: &mut Criterion) { fn verify_schnorr(c: &mut Criterion) { let mut group = c.benchmark_group("schnorr_verify"); - let keypair = schnorr.new_keypair(SK.clone()); + let keypair = schnorr.new_keypair(*SK); { let message = Message::::raw(MESSAGE); let sig = schnorr.sign(&keypair, message); @@ -59,7 +59,7 @@ fn verify_schnorr(c: &mut Criterion) { { use secp256k1::{KeyPair, Message, Secp256k1, XOnlyPublicKey}; let secp = Secp256k1::new(); - let kp = KeyPair::from_secret_key(&secp, &SK.clone().into()); + let kp = KeyPair::from_secret_key(&secp, &(*SK).into()); let pk = XOnlyPublicKey::from_keypair(&kp).0; let msg = Message::from_slice(&MESSAGE[..]).unwrap(); let sig = secp.sign_schnorr_no_aux_rand(&msg, &kp); diff --git a/schnorr_fun/src/frost.rs b/schnorr_fun/src/frost.rs index aaee981a..ec45267f 100644 --- a/schnorr_fun/src/frost.rs +++ b/schnorr_fun/src/frost.rs @@ -524,7 +524,7 @@ impl + Clone, NG: NonceGen> Frost { scalar_poly: &[Scalar], message: Message, ) -> Signature { - let key_pair = self.schnorr.new_keypair(scalar_poly[0].clone()); + let key_pair = self.schnorr.new_keypair(scalar_poly[0]); self.schnorr.sign(&key_pair, message) } @@ -848,10 +848,7 @@ impl + Clone, NG> Frost { let agg_nonce = nonce_map .iter() .fold([Point::zero(); 2], |acc, (_, nonce)| { - [ - g!({ acc[0] } + { nonce.0[0] }), - g!({ acc[1] } + { nonce.0[1] }), - ] + [g!(acc[0] + nonce.0[0]), g!(acc[1] + nonce.0[1])] }); let agg_nonce = [agg_nonce[0].normalize(), agg_nonce[1].normalize()]; @@ -864,12 +861,11 @@ impl + Clone, NG> Frost { .add(frost_key.public_key()) .add(message), ); - let (agg_nonce, nonces_need_negation) = - g!({ agg_nonce[0] } + binding_coeff * { agg_nonce[1] }) - .normalize() - .non_zero() - .unwrap_or(Point::generator()) - .into_point_with_even_y(); + let (agg_nonce, nonces_need_negation) = g!(agg_nonce[0] + binding_coeff * agg_nonce[1]) + .normalize() + .non_zero() + .unwrap_or(Point::generator()) + .into_point_with_even_y(); let challenge = self .schnorr @@ -1099,23 +1095,21 @@ fn scalar_poly_eval( poly: &[Scalar], x: Scalar, ) -> Scalar { - poly.iter() - .fold((s!(0), s!(1).mark_zero()), |(eval, xpow), coeff| { - (s!(eval + xpow * coeff), s!(xpow * x)) - }) - .0 + s!(powers(x) .* poly) } fn point_poly_eval( poly: &[Point], x: Scalar, ) -> Point { - let xpows = core::iter::successors(Some(s!(1).public().mark_zero()), |xpow| { - Some(s!(xpow * x).public()) + g!(powers(x) .* poly) +} + +/// Returns an iterator of 1, x, x², x³ ... +fn powers(x: Scalar) -> impl Iterator> { + core::iter::successors(Some(Scalar::one().mark_zero_choice::()), move |xpow| { + Some(s!(xpow * x).set_secrecy()) }) - .take(poly.len()) - .collect::>(); - secp256kfun::op::lincomb(&xpows, poly.iter()) } #[cfg(test)] diff --git a/schnorr_fun/src/musig.rs b/schnorr_fun/src/musig.rs index b84d80c4..0aa3fefe 100644 --- a/schnorr_fun/src/musig.rs +++ b/schnorr_fun/src/musig.rs @@ -312,7 +312,7 @@ impl + Clone, NG> MuSig { }) .collect::>(); - let agg_key = crate::fun::op::lincomb(coefs.iter(), keys.iter()) + let agg_key = g!(&coefs .* &keys) .non_zero().expect("computationally unreachable: linear combination of hash randomised points cannot add to zero"); AggKey { @@ -531,13 +531,10 @@ impl + Clone, NG> MuSig { ) { let mut Rs = nonces; let agg_Rs = Rs.iter().fold([Point::zero(); 2], |acc, nonce| { - [ - g!({ acc[0] } + { nonce.0[0] }), - g!({ acc[1] } + { nonce.0[1] }), - ] + [g!(acc[0] + nonce.0[0]), g!(acc[1] + nonce.0[1])] }); let agg_Rs = Nonce::([ - g!({ agg_Rs[0] } + encryption_key).normalize(), + g!(agg_Rs[0] + encryption_key).normalize(), agg_Rs[1].normalize(), ]); @@ -552,7 +549,7 @@ impl + Clone, NG> MuSig { .public() .mark_zero(); - let (R, r_needs_negation) = g!({ agg_Rs.0[0] } + b * { agg_Rs.0[1] }) + let (R, r_needs_negation) = g!(agg_Rs.0[0] + b * agg_Rs.0[1]) .normalize() .non_zero() .unwrap_or(Point::generator()) diff --git a/schnorr_fun/tests/against_c_lib.rs b/schnorr_fun/tests/against_c_lib.rs index b3fb1218..d40a93f7 100644 --- a/schnorr_fun/tests/against_c_lib.rs +++ b/schnorr_fun/tests/against_c_lib.rs @@ -63,7 +63,7 @@ proptest! { msg in any::<[u8;32]>(), ) { let secp = &*SECP; - let keypair = secp256k1::KeyPair::from_secret_key(secp, &key.clone().into()); + let keypair = secp256k1::KeyPair::from_secret_key(secp, &key.into()); let secp_msg = secp256k1::Message::from_slice(&msg).unwrap(); let sig = secp.sign_schnorr_no_aux_rand(&secp_msg, &keypair); let schnorr = Schnorr::::default(); diff --git a/schnorr_fun/tests/musig_sign_verify.rs b/schnorr_fun/tests/musig_sign_verify.rs index a2c8b433..17827c62 100644 --- a/schnorr_fun/tests/musig_sign_verify.rs +++ b/schnorr_fun/tests/musig_sign_verify.rs @@ -90,7 +90,7 @@ pub struct TestCase { fn musig_sign_verify() { let test_cases = serde_json::from_str::(TEST_JSON).unwrap(); let musig = musig::new_without_nonce_generation::(); - let keypair = musig.new_keypair(test_cases.sk.clone()); + let keypair = musig.new_keypair(test_cases.sk); for test_case in &test_cases.valid_test_cases { let pubkeys = test_case @@ -119,7 +119,7 @@ fn musig_sign_verify() { .unwrap() .nonce, ); - assert_eq!(partial_sig, test_case.expected.clone().unwrap()); + assert_eq!(partial_sig, test_case.expected.unwrap()); assert!(musig.verify_partial_signature( &agg_key, &session, diff --git a/secp256kfun/Cargo.toml b/secp256kfun/Cargo.toml index 626c7cdd..31df8c4b 100644 --- a/secp256kfun/Cargo.toml +++ b/secp256kfun/Cargo.toml @@ -18,6 +18,7 @@ keywords = ["bitcoin", "secp256k1"] digest = { version = "0.10", default-features = false } subtle = { package = "subtle-ng", version = "2", default-features = false } rand_core = { version = "0.6", default-features = false } +secp256kfun_arithmetic_macros = { version = "0.9.0", path = "../arithmetic_macros" } # optional serde = { version = "1.0", optional = true, default-features = false, features = ["derive"] } diff --git a/secp256kfun/benches/bench_ecmult.rs b/secp256kfun/benches/bench_ecmult.rs index 3de3fd0b..9e6fd57e 100755 --- a/secp256kfun/benches/bench_ecmult.rs +++ b/secp256kfun/benches/bench_ecmult.rs @@ -153,7 +153,7 @@ fn multi_mul(c: &mut Criterion) { ], ) }, - |(scalars, points)| op::lincomb(scalars.iter(), points.iter()), + |(scalars, points)| op::point_scalar_dot_product(scalars.iter(), points.iter()), BatchSize::SmallInput, ) }); diff --git a/secp256kfun/src/backend/k256_impl.rs b/secp256kfun/src/backend/k256_impl.rs index 5016c4b7..bdfa0907 100644 --- a/secp256kfun/src/backend/k256_impl.rs +++ b/secp256kfun/src/backend/k256_impl.rs @@ -202,11 +202,17 @@ impl TimeSensitive for ConstantTime { } #[cfg(feature = "alloc")] - fn lincomb_iter<'a, 'b, A: Iterator, B: Iterator>( + #[inline(always)] + fn lincomb_iter< + A: Iterator, + B: Iterator, + AT: AsRef, + BT: AsRef, + >( points: A, scalars: B, ) -> Point { - mul::lincomb_iter(points, scalars) + mul::lincomb_iter(points.map(|p| *p.as_ref()), scalars) } } @@ -327,7 +333,12 @@ impl TimeSensitive for VariableTime { } #[cfg(feature = "alloc")] - fn lincomb_iter<'a, 'b, A: Iterator, B: Iterator>( + fn lincomb_iter< + A: Iterator, + B: Iterator, + AT: AsRef, + BT: AsRef, + >( points: A, scalars: B, ) -> Point { diff --git a/secp256kfun/src/backend/mod.rs b/secp256kfun/src/backend/mod.rs index 2c54fa8d..af0f4ab4 100644 --- a/secp256kfun/src/backend/mod.rs +++ b/secp256kfun/src/backend/mod.rs @@ -1,5 +1,6 @@ //! These traits are for accounting for what methods each backend actually needs. mod k256_impl; + pub use k256_impl::*; pub trait BackendScalar: Sized { @@ -59,12 +60,30 @@ pub trait TimeSensitive { fn scalar_mul(lhs: &Scalar, rhs: &Scalar) -> Scalar; fn scalar_invert(scalar: &Scalar) -> Scalar; fn scalar_mul_basepoint(scalar: &Scalar, base: &BasePoint) -> Point; - fn lincomb_iter<'a, 'b, A: Iterator, B: Iterator>( + fn lincomb_iter< + A: Iterator, + B: Iterator, + AT: AsRef, + BT: AsRef, + >( points: A, scalars: B, ) -> Point { points.zip(scalars).fold(Point::zero(), |acc, (X, k)| { - Self::point_add_point(&acc, &Self::scalar_mul_point(k, X)) + Self::point_add_point(&acc, &Self::scalar_mul_point(k.as_ref(), X.as_ref())) + }) + } + fn scalar_lincomb_iter< + A: Iterator, + B: Iterator, + AT: AsRef, + BT: AsRef, + >( + scalars1: A, + scalars2: B, + ) -> Scalar { + scalars1.zip(scalars2).fold(Scalar::zero(), |acc, (a, b)| { + Self::scalar_add(&acc, &Self::scalar_mul(a.as_ref(), b.as_ref())) }) } } diff --git a/secp256kfun/src/keypair.rs b/secp256kfun/src/keypair.rs index 66eb10ad..d7af6b88 100644 --- a/secp256kfun/src/keypair.rs +++ b/secp256kfun/src/keypair.rs @@ -77,8 +77,8 @@ impl KeyPair { /// &original_secret_key == keypair.secret_key() /// || &-original_secret_key == keypair.secret_key() /// ); - /// assert!(g!({ keypair.secret_key() } * G).normalize().is_y_even()); - /// assert_eq!(g!({ keypair.secret_key() } * G), keypair.public_key()); + /// assert!(g!(keypair.secret_key() * G).normalize().is_y_even()); + /// assert_eq!(g!(keypair.secret_key() * G), keypair.public_key()); /// ``` /// /// [`Point`]: crate::Point diff --git a/secp256kfun/src/lib.rs b/secp256kfun/src/lib.rs index fe64f874..356d09af 100755 --- a/secp256kfun/src/lib.rs +++ b/secp256kfun/src/lib.rs @@ -57,6 +57,10 @@ pub use serde; #[cfg(feature = "bincode")] pub use bincode; +#[doc(hidden)] +/// these are helpers so we hide them. Actual g! macro is defined in macros.rs +pub use secp256kfun_arithmetic_macros as arithmetic_macros; + mod libsecp_compat; #[cfg(any(feature = "proptest", test))] mod proptest_impls; diff --git a/secp256kfun/src/macros.rs b/secp256kfun/src/macros.rs index fc2b2570..14243e84 100644 --- a/secp256kfun/src/macros.rs +++ b/secp256kfun/src/macros.rs @@ -1,221 +1,42 @@ -#[doc(hidden)] -#[macro_export] -macro_rules! _s { - (@dot [$($a:tt)*] [$($aa:ident).+] . $attr:ident $($t:tt)*) => { - $crate::_s!(@dot [$($a)*] [$($aa).+.$attr] $($t)*) - }; - (@dot [$($a:tt)*] [$($aa:ident).+] $($t:tt)*) => { - // no more dots to process to join them all together - $crate::_s!(@next [{$($aa).+.borrow()} $($a)*] $($t)*) - }; - (@scalar [$($a:tt)*] & $($t:tt)+) => { - core::compile_error!("Do not use ‘&’ in s!(...) expression"); - }; - (@scalar [$($a:tt)*] - $($t:tt)+) => { - $crate::_s!(@scalar [neg $($a)*] $($t)+) - }; - (@scalar [$($a:tt)*] $scalar:ident $($t:tt)*) => { - $crate::_s!(@dot [$($a)*] [$scalar] $($t)*) - }; - (@scalar [$($a:tt)*] 0 $($t:tt)*) => { - $crate::_s!(@next [{$crate::Scalar::<$crate::marker::Secret,_>::zero()} $($a)*] $($t)*) - }; - (@scalar [$($a:tt)*] $num:literal $($t:tt)*) => { - $crate::_s!(@next [{{ - // hack to check at compile time the thing is non-zero - let _ = [(); (($num as u32).count_ones() as usize) - 1]; - $crate::Scalar::<$crate::marker::Secret, $crate::marker::NonZero>::from_non_zero_u32( - unsafe { core::num::NonZeroU32::new_unchecked($num) }, - ) - }} $($a)*] $($t)*) - }; - (@scalar [$($a:tt)*] $block:block $($t:tt)*) => { - $crate::_s!(@next [$block $($a)*] $($t)*) - }; - (@scalar [$($a:tt)*] ($($subexpr:tt)+) $($t:tt)*) => { - $crate::_s!(@next [{$crate::_s!(@scalar [] $($subexpr)+)} $($a)*] $($t)*) - }; - - (@next [$stack0:block neg $($a:tt)*] $($t:tt)*) => { - $crate::_s!(@next [{core::ops::Neg::neg($stack0)} $($a)*] $($t)*) - }; - - (@next [$stack0:block $stack1:block mul $($a:tt)*] $($t:tt)*) => { - $crate::_s!(@next [{$crate::op::scalar_mul($stack1.borrow(), $stack0.borrow())} $($a)*] $($t)*) - }; - - (@next [$stack0:block $($a:tt)*] * $($t:tt)+) => { - $crate::_s!(@scalar [$stack0 mul $($a)*] $($t)*) - }; - - (@next [$stack0:block $stack1:block sub $($a:tt)*] $($t:tt)*) => { - $crate::_s!(@next [{$crate::op::scalar_sub($stack1.borrow(), $stack0.borrow())} $($a)*] $($t)*) - }; - - (@next [$stack0:block $stack1:block add $($a:tt)*] $($t:tt)*) => { - $crate::_s!(@next [{$crate::op::scalar_add($stack1.borrow(), $stack0.borrow())} $($a)*] $($t)*) - }; - - (@next [$stack0:block $($a:tt)*] - $($t:tt)+) => { - $crate::_s!(@scalar [$stack0 sub $($a)*] $($t)+) - }; - - (@next [$stack0:block $($a:tt)*] + $($t:tt)+) => { - $crate::_s!(@scalar [$stack0 add $($a)*] $($t)+) - }; - - (@next [$scalar:block]) => { - $scalar - }; - - (@next [$scalar:block stringify]) => { - stringify!($scalar) - }; -} - /// Scalar expression macro. +/// +/// Like [`g!`] except that the output of the expression is a [`Scalar`] rather than a [`Point`]. +/// +/// [`Scalar`]: crate::Scalar +/// [`Point`]: crate::Point +/// [`g!`]: crate::g #[macro_export] macro_rules! s { - (DEBUG $($t:tt)*) => {{ - #[allow(unused_imports)] - use core::borrow::Borrow; - $crate::_s!(@scalar [stringify] $($t)*) - }}; ($($t:tt)*) => {{ - #[allow(unused_imports)] - use core::borrow::Borrow; - $crate::_s!(@scalar [] $($t)*) + $crate::arithmetic_macros::gen_s!($crate $($t)*) }} } -#[doc(hidden)] -#[macro_export] -macro_rules! _g { - (@scalar [$($a:tt)*] & $($t:tt)+) => { - core::compile_error!("Do not use ‘&’ in g!(...) expression"); - }; - (@scalar [$($a:tt)*] - $($t:tt)+) => { - $crate::_g!(@scalar [neg $($a)*] $($t)+) - }; - (@scalar [$($a:tt)*] ($($expr:tt)+) * $($t:tt)+) => { - $crate::_g!(@point [s {$crate::_s!(@scalar [] $($expr)+)} $($a)*] $($t)+) - }; - (@scalar [$($a:tt)*] $ident:ident $($t:tt)*) => { - // We've got an identifier "foo" go and try to match foo.bar - // we don't know if this is a scalar yet. - $crate::_g!(@dot [$($a)*] [$ident] $($t)*) - }; - (@scalar [$($a:tt)*] $block:block * $($t:tt)*) => { - $crate::_g!(@point [s $block $($a)*] $($t)*) - }; - (@scalar [$($a:tt)*] 0 * $($t:tt)+) => { - $crate::_g!(@point [s {$crate::Scalar::<$crate::marker::Secret,_>::zero()} $($a)*] $($t)+) - }; - (@scalar [$($a:tt)*] $num:literal * $($t:tt)+) => { - $crate::_g!(@point [s {{ - // hack to check at compile time the thing is non-zero - let _ = [(); (($num as u32).count_ones() as usize) - 1]; - $crate::Scalar::<$crate::marker::Secret, $crate::marker::NonZero>::from_non_zero_u32( - unsafe { core::num::NonZeroU32::new_unchecked($num) }, - ) - }} $($a)*] $($t)+) - }; - (@scalar [$($a:tt)*] $($t:tt)+) => { - // failed to find scalar look for point instead - $crate::_g!(@point [$($a)*] $($t)+) - }; - (@dot [$($a:tt)*] [$($aa:ident).+] . $attr:ident $($t:tt)*) => { - $crate::_g!(@dot [$($a)*] [$($aa).+.$attr] $($t)*) - }; - (@dot [$($a:tt)*] [$($aa:ident).+] * $($t:tt)*) => { - // no more dots to process and we seem to have a scalar. - // Join them together and look for a point. - $crate::_g!(@point [s {$($aa).+.borrow()} $($a)*] $($t)*) - }; - (@dot [$($a:tt)*] [$($aa:ident).+] $($t:tt)*) => { - // no more dots to process and it looks like this was a point - // so go onto the next operator - $crate::_g!(@next [{$($aa).+.borrow()} $($a)*] $($t)*) - }; - (@point [$($a:tt)*] $point:ident $($t:tt)*) => { - $crate::_g!(@dot [$($a)*] [$point] $($t)*) - }; - (@point [$($a:tt)*] $block:block $($t:tt)*) => { - $crate::_g!(@next [$block $($a)*] $($t)*) - }; - (@point [$($a:tt)*] ($($expr:tt)+) $($t:tt)*) => { - $crate::_g!(@next [{ $crate::_g!(@scalar [] $($expr)+) } $($a)*] $($t)*) - }; - (@next [$point0:block s $scalar0:block neg $($a:tt)*] $($t:tt)*) => { - $crate::_g!(@next [$point0 s {core::ops::Neg::neg($scalar0.borrow())} $($a)*] $($t)*) - }; - (@next [$point0:block neg $($a:tt)*] $($t:tt)*) => { - $crate::_g!(@next [{core::ops::Neg::neg($point0)} $($a)*] $($t)*) - }; - (@next [$point0:block s $scalar0:block $point1:block s $scalar1:block add $($a:tt)*] $($t:tt)*) => { - - $crate::_g!(@next [{ - $crate::op::double_mul( - $scalar0.borrow(), - $point0.borrow(), - $scalar1.borrow(), - $point1.borrow() - )} $($a)*] $($t)*) - }; - (@next [$point0:block s $scalar0:block $point1:block s $scalar1:block sub $($a:tt)*] $($t:tt)*) => { - $crate::_g!(@next [{ - $crate::op::double_mul( - &core::ops::Neg::neg($scalar0), - $point0.borrow(), - $scalar1.borrow(), - $point1.borrow() - )} $($a)*] $($t)*) - }; - (@next [$point0:block $(s $scalar0:block)? $point1:block $(s $scalar1:block)? add $($a:tt)*] $($t:tt)*) => { - $crate::_g!(@next [ - {$crate::op::point_add( - $crate::_g!(@next [$point1 $(s $scalar1)?]).borrow(), - $crate::_g!(@next [$point0 $(s $scalar0)?]).borrow() - )} $($a)*] $($t)*) - }; - (@next [$point0:block $(s $scalar0:block)? $point1:block $(s $scalar1:block)? sub $($a:tt)*] $($t:tt)*) => { - $crate::_g!(@next [ - {$crate::op::point_sub( - $crate::_g!(@next [$point1 $(s $scalar1)?]).borrow(), - $crate::_g!(@next [$point0 $(s $scalar0)?]).borrow() - )} - $($a)*] $($t)*) - }; - (@next [$point0:block s $scalar0:block $($a:tt)*] + $($t:tt)*) => { - $crate::_g!(@scalar [$point0 s $scalar0 add $($a)*] $($t)*) - }; - (@next [$point0:block $($a:tt)*] + $($t:tt)+) => { - $crate::_g!(@scalar [$point0 add $($a)*] $($t)+) - }; - (@next [$point0:block s $scalar0:block $($a:tt)*] - $($t:tt)+) => { - $crate::_g!(@scalar [$point0 s $scalar0 sub $($a)*] $($t)+) - }; - (@next [$point0:block $($a:tt)*] - $($t:tt)+) => { - $crate::_g!(@scalar [$point0 sub $($a)*] $($t)+) - }; - (@next [$point0:block s $scalar0:block]) => { - $crate::op::scalar_mul_point($scalar0.borrow(), $point0.borrow()) - }; - (@next [$point0:block]) => { - $point0 - } -} - /// Group operation expression macro. /// -/// The `g!` macro lets you express a set of scalar multiplications and group -/// additions/substraction. This compiles down to operations from the [`op`] -/// module. Apart from being far more readable, the idea is that `g!` will (or -/// may in the future) compile to more efficient operations than if you were to -/// manually call the functions from `op` yourself. +/// The `g!` macro lets you express scalar multiplications and group operations conveniently +/// following standard [order of operations]. This compiles down to operations from the [`op`] +/// module. Apart from being far more readable, the idea is that `g!` will (or may in the future) +/// compile to more efficient operations than if you were to manually call the functions from [`op`] +/// yourself. +/// +/// Note you can but often don't need to put a `&` in front of the terms in the expression. /// -/// As a bonus, you don't need to put reference `&` makers on terms in `g!` this -/// is done automatically if necessary. +/// # Syntax and operations +/// +/// The expression supports the following operations: +/// +/// - ` * ` multiplies the `point` by `scalar` +/// - ` + ` adds two points +/// - ` - ` subtracts one point from another +/// - ` .* ` does a [dot product](https://en.wikipedia.org/wiki/Dot_product) +/// between a list of points and scalars. If one list is shorter than the other then the excess +/// points or scalars will be multiplied by 0. See [`op::point_scalar_dot_product`]. +/// +/// The terms of the expression can be any variable followed by simple method calls, attribute +/// access etc. If your term involves more expressions (anything involving specifying types using +/// `::`) then you can use `{..}` to surround arbitrary expressions. You can also use `(..)` to +/// group arithmetic expressions to override the usual operation order. /// /// # Examples /// @@ -234,11 +55,11 @@ macro_rules! _g { /// let H = Point::random(&mut rand::thread_rng()); /// let minus = g!(x * G - y * H); /// let plus = g!(x * G + y * H); -/// // note the parenthesis around the scalar sub expression -/// assert_eq!(g!(plus + minus), g!((2 * x) * G)); +/// assert_eq!(g!(plus + minus), g!(2 * x * G)); // this will do 2 * x first +/// assert_eq!(g!(42 * (G + H)), g!((42 * G + 42 * H))); /// ``` /// -/// You may access attributes: +/// You may access attributes and call methods: /// /// ``` /// # use secp256kfun::{g, Point, Scalar, G}; @@ -253,28 +74,27 @@ macro_rules! _g { /// }; /// /// let result = g!(mul.scalar * mul.point); +/// assert_eq!(g!(mul.scalar.invert() * result), mul.point); /// ``` /// /// You can put an arbitrary expressions inside `{...}` /// /// ``` /// # use secp256kfun::{g, Point, Scalar, G}; -/// let x = Scalar::random(&mut rand::thread_rng()); -/// let Xinv = g!({ x.invert() } * G); -/// assert_eq!(g!(x * Xinv), *G); +/// let random_point = g!({ Scalar::random(&mut rand::thread_rng()) } * G); /// ``` /// /// [`double_mul`]: crate::op::double_mul /// [`G`]: crate::G /// [`Point`]: crate::Point /// [`op`]: crate::op - +/// [order of operations]: https://en.wikipedia.org/wiki/Order_of_operations +/// [`op::point_scalar_dot_product`]: crate::op::point_scalar_dot_product #[macro_export] macro_rules! g { - ($($t:tt)+) => {{ - #[allow(unused_imports)] - use core::borrow::Borrow; - $crate::_g!(@scalar [] $($t)+) }}; + ($($t:tt)*) => {{ + $crate::arithmetic_macros::gen_g!($crate $($t)*) + }} } /// Macro to make nonce derivation clear and explicit. @@ -315,6 +135,7 @@ macro_rules! derive_nonce { public => [$($public:expr),+]$(,)? ) => {{ use $crate::hash::HashAdd; + #[allow(unused_imports)] use core::borrow::Borrow; use $crate::nonce::NonceGen; Scalar::from_hash( diff --git a/secp256kfun/src/marker/zero_choice.rs b/secp256kfun/src/marker/zero_choice.rs index 0829d890..e001fc15 100644 --- a/secp256kfun/src/marker/zero_choice.rs +++ b/secp256kfun/src/marker/zero_choice.rs @@ -22,6 +22,7 @@ pub trait ZeroChoice: + Copy + DecideZero + DecideZero + + DecideZero + core::hash::Hash + Ord + PartialOrd diff --git a/secp256kfun/src/op.rs b/secp256kfun/src/op.rs index 3bcd2cdf..74f05afc 100644 --- a/secp256kfun/src/op.rs +++ b/secp256kfun/src/op.rs @@ -6,40 +6,29 @@ //! macros which compile your expressions into (potentially more efficient) //! calls to these functions. //! -//! Most of the functions here are [`specialized`] so the compiler may be able to -//! choose a faster algorithm depending on the arguments. For example scalar -//! multiplications are faster points marked `BasePoint` like [`G`], so in the -//! following snippet computing `X1` will be computed faster than `X2` even -//! though the same function is being called. -//! ``` -//! use secp256kfun::{marker::*, op, Scalar, G}; -//! let x = Scalar::random(&mut rand::thread_rng()); -//! let X1 = op::scalar_mul_point(&x, G); // fast -//! let H = &G.normalize(); // scrub `BasePoint` marker -//! let X2 = op::scalar_mul_point(&x, &H); // slow -//! assert_eq!(X1, X2); -//! ``` +//! Some of the functions here use the type parameters to try and optimize the operation they +//! perform. +//! //! [`Points`]: crate::Point //! [`Scalars`]: crate::Scalar -//! [`specialized`]: https://github.com/rust-lang/rust/issues/31844 -//! [`G`]: crate::G #[allow(unused_imports)] use crate::{ backend::{self, ConstantTime, TimeSensitive, VariableTime}, marker::*, Point, Scalar, }; +use core::borrow::Borrow; /// Computes `x * A + y * B` more efficiently than calling [`scalar_mul_point`] twice. #[inline(always)] pub fn double_mul( - x: &Scalar, - A: &Point, - y: &Scalar, - B: &Point, + x: impl Borrow>, + A: impl Borrow>, + y: impl Borrow>, + B: impl Borrow>, ) -> Point { Point::from_inner( - ConstantTime::point_double_mul(&x.0, &A.0, &y.0, &B.0), + ConstantTime::point_double_mul(&x.borrow().0, &A.borrow().0, &y.borrow().0, &B.borrow().0), NonNormal, ) } @@ -47,54 +36,69 @@ pub fn double_mul( /// Computes multiplies the point `P` by the scalar `x`. #[inline(always)] pub fn scalar_mul_point( - x: &Scalar, - P: &Point, + x: impl Borrow>, + P: impl Borrow>, ) -> Point where Z1: DecideZero, { - Point::from_inner(ConstantTime::scalar_mul_point(&x.0, &P.0), NonNormal) + Point::from_inner( + ConstantTime::scalar_mul_point(&x.borrow().0, &P.borrow().0), + NonNormal, + ) } /// Multiplies two scalars together (modulo the curve order) #[inline(always)] -pub fn scalar_mul(x: &Scalar, y: &Scalar) -> Scalar +pub fn scalar_mul( + x: impl Borrow>, + y: impl Borrow>, +) -> Scalar where Z1: DecideZero, { - Scalar::from_inner(ConstantTime::scalar_mul(&x.0, &y.0)) + Scalar::from_inner(ConstantTime::scalar_mul(&x.borrow().0, &y.borrow().0)) } /// Adds two scalars together (modulo the curve order) #[inline(always)] -pub fn scalar_add(x: &Scalar, y: &Scalar) -> Scalar { - Scalar::from_inner(ConstantTime::scalar_add(&x.0, &y.0)) +pub fn scalar_add( + x: impl Borrow>, + y: impl Borrow>, +) -> Scalar { + Scalar::from_inner(ConstantTime::scalar_add(&x.borrow().0, &y.borrow().0)) } /// Subtracts one scalar from another #[inline(always)] -pub fn scalar_sub(x: &Scalar, y: &Scalar) -> Scalar { - Scalar::from_inner(ConstantTime::scalar_sub(&x.0, &y.0)) +pub fn scalar_sub( + x: impl Borrow>, + y: impl Borrow>, +) -> Scalar { + Scalar::from_inner(ConstantTime::scalar_sub(&x.borrow().0, &y.borrow().0)) } /// Checks equality between two scalars #[inline(always)] -pub fn scalar_eq(x: &Scalar, y: &Scalar) -> bool { - ConstantTime::scalar_eq(&x.0, &y.0) +pub fn scalar_eq( + x: impl Borrow>, + y: impl Borrow>, +) -> bool { + ConstantTime::scalar_eq(&x.borrow().0, &y.borrow().0) } /// Negate a scalar #[inline(always)] -pub fn scalar_negate(x: &Scalar) -> Scalar { - let mut negated = x.0; +pub fn scalar_negate(x: impl Borrow>) -> Scalar { + let mut negated = x.borrow().0; ConstantTime::scalar_cond_negate(&mut negated, true); Scalar::from_inner(negated) } /// Invert a scalar #[inline(always)] -pub fn scalar_invert(x: &Scalar) -> Scalar { - Scalar::from_inner(ConstantTime::scalar_invert(&x.0)) +pub fn scalar_invert(x: impl Borrow>) -> Scalar { + Scalar::from_inner(ConstantTime::scalar_invert(&x.borrow().0)) } /// Conditionally negate a scalar @@ -118,19 +122,25 @@ pub fn scalar_is_zero(x: &Scalar) -> bool { /// Subtracts one point from another #[inline(always)] pub fn point_sub( - A: &Point, - B: &Point, + A: impl Borrow>, + B: impl Borrow>, ) -> Point { - Point::from_inner(ConstantTime::point_sub_point(&A.0, &B.0), NonNormal) + Point::from_inner( + ConstantTime::point_sub_point(&A.borrow().0, &B.borrow().0), + NonNormal, + ) } /// Adds two points together #[inline(always)] pub fn point_add( - A: &Point, - B: &Point, + A: impl Borrow>, + B: impl Borrow>, ) -> Point { - Point::from_inner(ConstantTime::point_add_point(&A.0, &B.0), NonNormal) + Point::from_inner( + ConstantTime::point_add_point(&A.borrow().0, &B.borrow().0), + NonNormal, + ) } /// Checks if two points are equal @@ -150,8 +160,10 @@ where /// Negate a point #[inline(always)] -pub fn point_negate(A: &Point) -> Point { - let mut A = A.0; +pub fn point_negate( + A: impl Borrow>, +) -> Point { + let mut A = A.borrow().0; ConstantTime::any_point_neg(&mut A); Point::from_inner(A, T::NegationType::default()) } @@ -159,10 +171,10 @@ pub fn point_negate(A: &Point) -> Point( - A: &Point, + A: impl Borrow>, cond: bool, ) -> Point { - let mut A = A.0; + let mut A = A.borrow().0; ConstantTime::any_point_conditional_negate(&mut A, cond); Point::from_inner(A, T::NegationType::default()) } @@ -179,8 +191,34 @@ where Point::from_inner(A.0, Normal) } +/// Does a [dot product](https://en.wikipedia.org/wiki/Dot_product) of points with scalars +/// +/// If one of the iterators is longer than the other then the excess points or scalars will be +/// multiplied by 0. +#[inline(always)] +pub fn point_scalar_dot_product< + T1, + S1, + Z1, + S2, + Z2, + I2: Borrow> + AsRef, + I1: Borrow> + AsRef, +>( + scalars: impl IntoIterator, + points: impl IntoIterator, +) -> Point { + Point::from_inner( + ConstantTime::lincomb_iter(points.into_iter(), scalars.into_iter()), + NonNormal, + ) +} + /// Does a linear combination of points +/// +/// ⚠ deprecated in favor of [`point_scalar_dot_product`] which has a more convienient API and name. #[inline(always)] +#[deprecated(since = "0.10.0", note = "use point_scalar_dot_product instead")] pub fn lincomb<'a, T1: 'a, S1: 'a, Z1: 'a, S2: 'a, Z2: 'a>( scalars: impl IntoIterator>, points: impl IntoIterator>, @@ -194,6 +232,27 @@ pub fn lincomb<'a, T1: 'a, S1: 'a, Z1: 'a, S2: 'a, Z2: 'a>( ) } +#[inline(always)] +/// Does a [dot product] between two iterators of scalars. +/// +/// If one of the iterators is longer than the other then the excess scalars will be multipled by 0. +pub fn scalar_dot_product< + S1, + Z1, + S2, + Z2, + I1: Borrow> + AsRef, + I2: Borrow> + AsRef, +>( + scalars1: impl IntoIterator, + scalars2: impl IntoIterator, +) -> Scalar { + Scalar::from_inner(ConstantTime::scalar_lincomb_iter( + scalars1.into_iter(), + scalars2.into_iter(), + )) +} + /// Check if a point has an even y-coordinate #[inline(always)] pub fn point_is_y_even(A: &Point) -> bool { @@ -242,8 +301,8 @@ mod test { C in any::() ) { use crate::op::*; - assert_eq!(lincomb([&a,&b,&c], [&A,&B,&C]), - point_add(&scalar_mul_point(&a, &A), &point_add(&scalar_mul_point(&b, &B), &scalar_mul_point(&c, &C)))) + assert_eq!(point_scalar_dot_product([&a,&b,&c], [&A,&B,&C]), + point_add(scalar_mul_point(a, A), point_add(scalar_mul_point(b, B), scalar_mul_point(c, C)))) } } } diff --git a/secp256kfun/src/point.rs b/secp256kfun/src/point.rs index 14dcd5f9..e057fc07 100644 --- a/secp256kfun/src/point.rs +++ b/secp256kfun/src/point.rs @@ -72,7 +72,13 @@ impl Clone for Point { } } -impl Copy for Point {} +impl AsRef for Point { + fn as_ref(&self) -> &backend::Point { + &self.0 + } +} + +impl Copy for Point {} impl Point { /// Samples a point uniformly from the group. @@ -222,7 +228,7 @@ impl Point { base: &Point, scalar: &mut Scalar, ) -> Self { - let point = crate::op::scalar_mul_point(scalar, base); + let point = crate::op::scalar_mul_point(*scalar, base); let (point, needs_negation) = point.into_point_with_even_y(); scalar.conditional_negate(needs_negation); point @@ -256,7 +262,7 @@ impl Point { where T: PointType, { - op::point_conditional_negate(&self.clone(), cond) + op::point_conditional_negate(*self, cond) } /// Set the [`Secrecy`] of the point. @@ -329,7 +335,7 @@ impl Point { impl core::ops::Neg for Point { type Output = Point; fn neg(self) -> Self::Output { - op::point_negate(&self) + op::point_negate(self) } } @@ -560,25 +566,25 @@ crate::impl_fromstr_deserialize! { impl AddAssign> for Point { fn add_assign(&mut self, rhs: Point) { - *self = crate::op::point_add(self, &rhs).set_secrecy::() + *self = crate::op::point_add(*self, &rhs).set_secrecy::() } } impl AddAssign<&Point> for Point { fn add_assign(&mut self, rhs: &Point) { - *self = crate::op::point_add(self, rhs).set_secrecy::() + *self = crate::op::point_add(*self, rhs).set_secrecy::() } } impl SubAssign<&Point> for Point { fn sub_assign(&mut self, rhs: &Point) { - *self = crate::op::point_sub(self, rhs).set_secrecy::() + *self = crate::op::point_sub(*self, rhs).set_secrecy::() } } impl SubAssign> for Point { fn sub_assign(&mut self, rhs: Point) { - *self = crate::op::point_sub(self, &rhs).set_secrecy::() + *self = crate::op::point_sub(*self, &rhs).set_secrecy::() } } @@ -835,7 +841,7 @@ mod test { let mut a = a_orig; let b = Point::random(&mut rand::thread_rng()); a += b; - assert_eq!(a, op::point_add(&a_orig, &b)); + assert_eq!(a, op::point_add(a_orig, b)); a -= b; assert_eq!(a, a_orig); } diff --git a/secp256kfun/src/scalar.rs b/secp256kfun/src/scalar.rs index 5f538fe2..a7d883a7 100644 --- a/secp256kfun/src/scalar.rs +++ b/secp256kfun/src/scalar.rs @@ -51,11 +51,17 @@ use rand_core::RngCore; /// [`ZeroChoice]: crate::marker::ZeroChoice pub struct Scalar(pub(crate) backend::Scalar, PhantomData<(Z, S)>); -impl Copy for Scalar {} +impl Copy for Scalar {} + +impl AsRef for Scalar { + fn as_ref(&self) -> &backend::Scalar { + &self.0 + } +} impl Clone for Scalar { fn clone(&self) -> Self { - Self(self.0, self.1) + *self } } @@ -143,6 +149,39 @@ impl Scalar { pub fn minus_one() -> Self { Self::from_inner(backend::BackendScalar::minus_one()) } + + /// Marks a scalar non-zero scalar as having the zero choice `Z` (rather than `NonZero`). + /// + /// Useful when writing code that preserves the zero choice of the caller. + /// + /// # Example + /// + /// ``` + /// use secp256kfun::{marker::*, s, Scalar}; + /// + /// /// Returns an iterator of 1, x, x², x³ ... + /// fn powers(x: Scalar) -> impl Iterator> { + /// core::iter::successors(Some(Scalar::one().mark_zero_choice::()), move |xpow| { + /// Some(s!(xpow * x).set_secrecy()) + /// }) + /// } + /// + /// assert_eq!(powers(s!(2)).take(4).collect::>(), vec![ + /// s!(1), + /// s!(2), + /// s!(4), + /// s!(8) + /// ]); + /// assert_eq!(powers(s!(0)).take(4).collect::>(), vec![ + /// s!(1).mark_zero(), + /// s!(0), + /// s!(0), + /// s!(0) + /// ]); + /// ``` + pub fn mark_zero_choice(self) -> Scalar { + Scalar::from_inner(self.0) + } } impl Scalar { @@ -320,7 +359,7 @@ impl core::ops::Neg for Scalar { type Output = Scalar; fn neg(self) -> Self::Output { - crate::op::scalar_negate(&self) + crate::op::scalar_negate(self) } } @@ -358,49 +397,49 @@ where impl AddAssign> for Scalar { fn add_assign(&mut self, rhs: Scalar) { - *self = crate::op::scalar_add(self, &rhs).set_secrecy::(); + *self = crate::op::scalar_add(*self, rhs).set_secrecy::(); } } impl AddAssign<&Scalar> for Scalar { fn add_assign(&mut self, rhs: &Scalar) { - *self = crate::op::scalar_add(self, rhs).set_secrecy::(); + *self = crate::op::scalar_add(*self, rhs).set_secrecy::(); } } impl SubAssign<&Scalar> for Scalar { fn sub_assign(&mut self, rhs: &Scalar) { - *self = crate::op::scalar_sub(self, rhs).set_secrecy::(); + *self = crate::op::scalar_sub(*self, rhs).set_secrecy::(); } } impl SubAssign> for Scalar { fn sub_assign(&mut self, rhs: Scalar) { - *self = crate::op::scalar_sub(self, &rhs).set_secrecy::(); + *self = crate::op::scalar_sub(*self, rhs).set_secrecy::(); } } impl MulAssign> for Scalar { fn mul_assign(&mut self, rhs: Scalar) { - *self = crate::op::scalar_mul(self, &rhs).set_secrecy::(); + *self = crate::op::scalar_mul(*self, rhs).set_secrecy::(); } } impl MulAssign<&Scalar> for Scalar { fn mul_assign(&mut self, rhs: &Scalar) { - *self = crate::op::scalar_mul(self, rhs).set_secrecy::(); + *self = crate::op::scalar_mul(*self, rhs).set_secrecy::(); } } impl MulAssign> for Scalar { fn mul_assign(&mut self, rhs: Scalar) { - *self = crate::op::scalar_mul(self, &rhs).set_secrecy::(); + *self = crate::op::scalar_mul(*self, rhs).set_secrecy::(); } } impl MulAssign<&Scalar> for Scalar { fn mul_assign(&mut self, rhs: &Scalar) { - *self = crate::op::scalar_mul(self, rhs).set_secrecy::(); + *self = crate::op::scalar_mul(*self, rhs).set_secrecy::(); } } @@ -506,7 +545,7 @@ mod test { -Scalar::::one() ); assert_eq!( - op::scalar_mul(&s!(3), &Scalar::::minus_one()), + op::scalar_mul(s!(3), Scalar::::minus_one()), -s!(3) ); } diff --git a/secp256kfun/src/slice.rs b/secp256kfun/src/slice.rs index 2f25753b..00b9f06e 100644 --- a/secp256kfun/src/slice.rs +++ b/secp256kfun/src/slice.rs @@ -26,10 +26,7 @@ pub struct Slice<'a, S = Public> { impl<'a, S> Clone for Slice<'a, S> { fn clone(&self) -> Self { - Self { - inner: self.inner, - secrecy: PhantomData, - } + *self } } diff --git a/secp256kfun/src/vendor/k256/mul.rs b/secp256kfun/src/vendor/k256/mul.rs index 728b58fb..14e48fb8 100644 --- a/secp256kfun/src/vendor/k256/mul.rs +++ b/secp256kfun/src/vendor/k256/mul.rs @@ -411,9 +411,9 @@ impl MulAssign<&Scalar> for ProjectivePoint { /// Calculates a linear combination `sum(x[i] * k[i])`, `i = 0..N` #[cfg(feature = "alloc")] -pub fn lincomb_iter<'a, 'b>( - xs: impl Iterator, - ks: impl Iterator, +pub fn lincomb_iter, P: AsRef>( + xs: impl Iterator, + ks: impl Iterator, ) -> ProjectivePoint { use alloc::vec::Vec; use subtle::ConditionallyNegatable; @@ -424,8 +424,9 @@ pub fn lincomb_iter<'a, 'b>( let mut tables2 = Vec::with_capacity(size); let mut n = 0; - for (k, mut x) in ks.zip(xs.cloned()) { - let (mut r1, mut r2) = decompose_scalar(&k); + for (k, x) in ks.zip(xs) { + let mut x = x.as_ref().clone(); + let (mut r1, mut r2) = decompose_scalar(k.as_ref()); let mut x_beta = x.endomorphism(); let r1_is_high = r1.is_high(); let r2_is_high = r2.is_high(); diff --git a/secp256kfun/src/vendor/k256/projective.rs b/secp256kfun/src/vendor/k256/projective.rs index a648dfe2..95dc3441 100644 --- a/secp256kfun/src/vendor/k256/projective.rs +++ b/secp256kfun/src/vendor/k256/projective.rs @@ -490,3 +490,9 @@ impl<'a> Neg for &'a ProjectivePoint { ProjectivePoint::neg(self) } } + +impl AsRef for ProjectivePoint { + fn as_ref(&self) -> &ProjectivePoint { + self + } +} diff --git a/secp256kfun/src/vendor/k256/scalar.rs b/secp256kfun/src/vendor/k256/scalar.rs index 1bc33b11..4e7d573a 100644 --- a/secp256kfun/src/vendor/k256/scalar.rs +++ b/secp256kfun/src/vendor/k256/scalar.rs @@ -429,3 +429,9 @@ impl From<&Scalar> for FieldBytes { scalar.to_bytes() } } + +impl AsRef for Scalar { + fn as_ref(&self) -> &Scalar { + self + } +} diff --git a/secp256kfun/tests/against_c_lib.rs b/secp256kfun/tests/against_c_lib.rs index 6e1fd441..19708100 100755 --- a/secp256kfun/tests/against_c_lib.rs +++ b/secp256kfun/tests/against_c_lib.rs @@ -51,10 +51,10 @@ mod against_c_lib { let result = { let H = g!({ Scalar::::from_bytes_mod_order(scalar_H) } * G); double_mul( - &Scalar::::from_bytes_mod_order(x).public(), + Scalar::::from_bytes_mod_order(x).public(), G, - &Scalar::::from_bytes_mod_order(y).public(), - &H, + Scalar::::from_bytes_mod_order(y).public(), + H, ) .normalize().non_zero() .unwrap() diff --git a/secp256kfun/tests/expression_macros.rs b/secp256kfun/tests/expression_macros.rs index d040e0b8..cbc8838b 100755 --- a/secp256kfun/tests/expression_macros.rs +++ b/secp256kfun/tests/expression_macros.rs @@ -17,53 +17,61 @@ fn s_expressions_give_correct_answers() { let b = s!(5); let c = s!(11); - assert_eq!(s!(a), &a); - assert_eq!(s!({ a.invert() }), a.invert()); + assert_eq!(s!(a), a); + assert_eq!(s!(a.invert()), a.invert()); assert_eq!(s!(-a), -&a); - assert_eq!(s!(a + b), op::scalar_add(&a, &b)); - assert_eq!(s!(-a + b), op::scalar_sub(&b, &a)); - assert_eq!(s!(a - b), op::scalar_sub(&a, &b)); - assert_eq!(s!({ a.invert() } * a), s!(1)); + assert_eq!(s!(a + b), op::scalar_add(a, b)); + assert_eq!(s!(-a + b), op::scalar_sub(b, a)); + assert_eq!(s!(a - b), op::scalar_sub(a, b)); + assert_eq!(s!(a.invert() * a), s!(1)); + assert_eq!(s!(a / a), s!(1)); + assert_eq!(s!(a / a * b), b); + assert_eq!(s!(a * b / a), b); assert_eq!( s!(a + b - c - a), - op::scalar_sub(&op::scalar_sub(&op::scalar_add(&a, &b), &c), &a) + op::scalar_sub(op::scalar_sub(op::scalar_add(a, b), c), a) ); - assert_eq!(s!(a - c * b), op::scalar_sub(&a, &op::scalar_mul(&c, &b))); + assert_eq!(s!(a - c * b), op::scalar_sub(a, op::scalar_mul(c, b))); assert_eq!( s!(a - c * b - a), - op::scalar_sub(&op::scalar_sub(&a, &op::scalar_mul(&c, &b)), &a) + op::scalar_sub(op::scalar_sub(a, op::scalar_mul(c, b)), a) ); assert_eq!( s!(a - c * b + a), - op::scalar_add(&op::scalar_sub(&a, &op::scalar_mul(&c, &b)), &a) - ); - assert_eq!(s!(a * b), op::scalar_mul(&a, &b)); - assert_eq!(s!(a * -b), op::scalar_mul(&a, &-&b)); - assert_eq!(s!(a * b - c), op::scalar_sub(&op::scalar_mul(&a, &b), &c)); - assert_eq!(s!(a * (b + c)), op::scalar_mul(&a, &op::scalar_add(&b, &c))); - assert_eq!( - s!(a * -(b + c)), - op::scalar_mul(&a, &-op::scalar_add(&b, &c)) + op::scalar_add(op::scalar_sub(a, op::scalar_mul(c, b)), a) ); + assert_eq!(s!(a * b), op::scalar_mul(a, b)); + assert_eq!(s!(a * -b), op::scalar_mul(a, -b)); + assert_eq!(s!(a * b - c), op::scalar_sub(op::scalar_mul(a, b), c)); + assert_eq!(s!(a * (b + c)), op::scalar_mul(a, op::scalar_add(b, c))); + assert_eq!(s!(a * -(b + c)), op::scalar_mul(a, -op::scalar_add(b, c))); assert_eq!( s!(a * -(b + c) + a), - op::scalar_add(&op::scalar_mul(&a, &-op::scalar_add(&b, &c)), &a) + op::scalar_add(op::scalar_mul(a, -op::scalar_add(b, c)), a) ); - assert_eq!(s!(a * b * c), op::scalar_mul(&a, &op::scalar_mul(&b, &c))); - assert_eq!(s!(-a * b * -c), op::scalar_mul(&a, &op::scalar_mul(&b, &c))); + assert_eq!(s!(a * b * c), op::scalar_mul(a, op::scalar_mul(b, c))); + assert_eq!(s!(-a * b * -c), op::scalar_mul(a, op::scalar_mul(b, c))); let has_scalar = Has { has: s!(17) }; - assert_eq!(s!(has_scalar.has * a), op::scalar_mul(&has_scalar.has, &a)); + assert_eq!(s!(has_scalar.has * a), op::scalar_mul(has_scalar.has, a)); let has_has_scalar = HasHas { has_has: has_scalar.clone(), }; assert_eq!( s!(has_has_scalar.has_has.has * a), - op::scalar_mul(&has_scalar.has, &a) + op::scalar_mul(has_scalar.has, a) ); assert_eq!(s!(3 * 11 + 5), s!(a * c + b)); + + let x = [a, b, c]; + let y = [s!(13), s!(17), s!(19)]; + + assert_eq!(s!(x .* y), s!(a * y[0] + b * y[1] + c * y[2])); + let ref_x = &x; + let ref_y = &y; + assert_eq!(s!(ref_x .* ref_y), s!(a * y[0] + b * y[1] + c * y[2])); } #[test] @@ -75,85 +83,61 @@ fn g_expressions_give_correct_answers() { let B = Point::random(&mut rand::thread_rng()); let C = Point::random(&mut rand::thread_rng()); - assert_eq!(g!(A), &A); + assert_eq!(g!(A), A); assert_eq!(g!(-A), -&A); - assert_eq!(g!(x * A), op::scalar_mul_point(&x, &A)); - assert_eq!(g!(-x * A), op::scalar_mul_point(&-&x, &A)); - assert_eq!(g!(A - B), op::point_sub(&A, &B)); - assert_eq!(g!(A + -B), op::point_sub(&A, &B)); - assert_eq!(g!(A + B), op::point_add(&A, &B)); - assert_eq!( - g!(x * A + B), - op::point_add(&op::scalar_mul_point(&x, &A), &B) - ); - assert_eq!( - g!(x * A - B), - op::point_sub(&op::scalar_mul_point(&x, &A), &B) - ); + assert_eq!(g!(x * A), op::scalar_mul_point(x, A)); + assert_eq!(g!(-x * A), op::scalar_mul_point(-x, A)); + assert_eq!(g!(A - B), op::point_sub(A, B)); + assert_eq!(g!(A + -B), op::point_sub(A, B)); + assert_eq!(g!(A + B), op::point_add(A, B)); + assert_eq!(g!(x * A + B), op::point_add(op::scalar_mul_point(x, A), B)); + assert_eq!(g!(x * A - B), op::point_sub(op::scalar_mul_point(x, A), B)); assert_eq!( g!(-x * A + B), - op::point_add(&op::scalar_mul_point(&-&x, &A), &B) + op::point_add(op::scalar_mul_point(-x, A), B) ); assert_eq!( g!(-x * A - B), - op::point_sub(&op::scalar_mul_point(&-&x, &A), &B) - ); - assert_eq!( - g!(A + x * B), - op::point_add(&A, &op::scalar_mul_point(&x, &B)) - ); - assert_eq!( - g!(A - x * B), - op::point_sub(&A, &op::scalar_mul_point(&x, &B)) + op::point_sub(op::scalar_mul_point(-x, A), B) ); - assert_eq!(g!(x * A + y * B), op::double_mul(&x, &A, &y, &B)); - assert_eq!(g!(x * A - y * B), op::double_mul(&x, &A, &-&y, &B)); + assert_eq!(g!(A + x * B), op::point_add(A, op::scalar_mul_point(x, B))); + assert_eq!(g!(A - x * B), op::point_sub(A, op::scalar_mul_point(x, B))); + assert_eq!(g!(x * A + y * B), op::double_mul(x, A, y, B)); + assert_eq!(g!(x * A - y * B), op::double_mul(x, A, -y, B)); assert_eq!( g!((x - x * y) * A + y * B), - op::double_mul(&op::scalar_sub(&x, &op::scalar_mul(&x, &y)), &A, &y, &B) + op::double_mul(op::scalar_sub(x, op::scalar_mul(x, y)), A, y, B) ); assert_eq!( g!(x * A + y * B + z * C), - op::point_add( - &op::double_mul(&x, &A, &y, &B), - &op::scalar_mul_point(&z, &C) - ) + op::point_add(op::double_mul(x, A, y, B), op::scalar_mul_point(z, C)) ); assert_eq!( g!(x * A - y * B + z * C), - op::point_add( - &op::double_mul(&x, &A, &-&y, &B), - &op::scalar_mul_point(&z, &C) - ) + op::point_add(op::double_mul(x, A, -y, B), op::scalar_mul_point(z, C)) ); assert_eq!( g!(x * A - y * B - z * C), - op::point_add( - &op::double_mul(&x, &A, &-&y, &B), - &op::scalar_mul_point(&-&z, &C) - ) + op::point_add(op::double_mul(x, A, -y, B), op::scalar_mul_point(-z, C)) ); assert_eq!( g!(x * A + y * B + C), - op::point_add(&op::double_mul(&x, &A, &y, &B), &C) + op::point_add(op::double_mul(x, A, y, B), C) ); assert_eq!( g!(x * A + y * B - C), - op::point_add(&op::double_mul(&x, &A, &y, &B), &-&C) + op::point_add(op::double_mul(x, A, y, B), -C) ); assert_eq!( g!(x * A - y * B + z * C), - op::point_add( - &op::double_mul(&x, &A, &-&y, &B), - &op::scalar_mul_point(&z, &C) - ) + op::point_add(op::double_mul(x, A, -y, B), op::scalar_mul_point(z, C)) ); let has_scalar = Has { has: s!(17) }; @@ -167,16 +151,16 @@ fn g_expressions_give_correct_answers() { assert_eq!( g!(has_scalar.has * A), - op::scalar_mul_point(&has_scalar.has, &A) + op::scalar_mul_point(has_scalar.has, A) ); assert_eq!( g!(has_has_scalar.has_has.has * A), - op::scalar_mul_point(&has_scalar.has, &A) + op::scalar_mul_point(has_scalar.has, A) ); assert_eq!( g!(x * has_point.has + y * has_has_point.has_has.has), - op::double_mul(&x, &has_point.has, &y, &has_has_point.has_has.has) + op::double_mul(x, has_point.has, y, has_has_point.has_has.has) ); } diff --git a/sigma_fun/src/or.rs b/sigma_fun/src/or.rs index 637c60ff..256582f0 100644 --- a/sigma_fun/src/or.rs +++ b/sigma_fun/src/or.rs @@ -214,14 +214,14 @@ mod test { let proof_system = crate::FiatShamir::>::default(); let proof_lhs = proof_system.prove( - &Either::Left(x.clone()), + &Either::Left(x), &statement, Some(&mut rand::thread_rng()), ); assert!(proof_system.verify(&statement, &proof_lhs)); let wrong_proof_lhs = proof_system.prove( - &Either::Right(x.clone()), + &Either::Right(x), &statement, Some(&mut rand::thread_rng()), );