Skip to content

Commit

Permalink
Use precomputation on input set keys
Browse files Browse the repository at this point in the history
  • Loading branch information
AaronFeickert committed Mar 21, 2024
1 parent 54c97b1 commit 81600b9
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ categories = ["cryptography"]
blake3 = { version = "1.5.0", default-features = false }
crypto-bigint = { version = "0.5.5", default-features = false }
curve25519-dalek = { version = "4.1.2", default-features = false, features = ["alloc", "digest", "rand_core", "zeroize"] }
derivative = { version = "2.2.0", default-features = false, features = ["use_core"] }
itertools = { version = "0.12.1", default-features = false }
merlin = { version = "3.0.0", default-features = false }
rand_core = { version = "0.6.4", default-features = false }
Expand Down
14 changes: 9 additions & 5 deletions src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use core::{iter::once, slice, slice::ChunksExact};

use curve25519_dalek::{
ristretto::CompressedRistretto,
traits::{Identity, MultiscalarMul, VartimeMultiscalarMul},
traits::{Identity, MultiscalarMul, VartimePrecomputedMultiscalarMul},
RistrettoPoint,
Scalar,
};
Expand Down Expand Up @@ -174,6 +174,7 @@ impl Proof {
let r = witness.get_r();
let l = witness.get_l();
let M = statement.get_input_set().get_keys();
let precomputation = statement.get_input_set().get_precomputation();
let params = statement.get_params();
let J = statement.get_J();

Expand Down Expand Up @@ -327,7 +328,11 @@ impl Proof {

match timing {
OperationTiming::Constant => RistrettoPoint::multiscalar_mul(X_scalars, X_points),
OperationTiming::Variable => RistrettoPoint::vartime_multiscalar_mul(X_scalars, X_points),
OperationTiming::Variable => precomputation.vartime_mixed_multiscalar_mul(
X_scalars.take(M.len()),
once(rho),
once(params.get_G()),
),
}
})
.collect::<Vec<RistrettoPoint>>();
Expand Down Expand Up @@ -549,6 +554,7 @@ impl Proof {

// Extract common values for convenience
let M = first_statement.get_input_set().get_keys();
let precomputation = first_statement.get_input_set().get_precomputation();
let params = first_statement.get_params();

// Check that all proof semantics are valid for the statement
Expand Down Expand Up @@ -604,7 +610,6 @@ impl Proof {
.chain(once(params.get_G()))
.chain(params.get_CommitmentG().iter())
.chain(once(params.get_CommitmentH()))
.chain(M.iter())
.chain(once(params.get_U()))
.collect::<Vec<&RistrettoPoint>>();

Expand Down Expand Up @@ -743,11 +748,10 @@ impl Proof {
scalars.push(G_scalar);
scalars.extend(CommitmentG_scalars);
scalars.push(CommitmentH_scalar);
scalars.extend(M_scalars);
scalars.push(U_scalar);

// Perform the final check; this can be done in variable time since it holds no secrets
if RistrettoPoint::vartime_multiscalar_mul(scalars.iter(), points) == RistrettoPoint::identity() {
if precomputation.vartime_mixed_multiscalar_mul(M_scalars, scalars, points) == RistrettoPoint::identity() {
Ok(())
} else {
Err(ProofError::FailedVerification)
Expand Down
28 changes: 26 additions & 2 deletions src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
use alloc::{sync::Arc, vec::Vec};

use blake3::Hasher;
use curve25519_dalek::{traits::Identity, RistrettoPoint};
use curve25519_dalek::{
ristretto::VartimeRistrettoPrecomputation,
traits::{Identity, VartimePrecomputedMultiscalarMul},
RistrettoPoint,
};
use derivative::Derivative;
use snafu::prelude::*;

use crate::parameters::Parameters;
Expand All @@ -14,12 +19,25 @@ use crate::parameters::Parameters;
/// An input set is constructed from a vector of verification keys.
/// Internally, it also contains cryptographic hash data to make proofs more efficient.
#[allow(non_snake_case)]
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub struct InputSet {
#[derivative(Debug = "ignore")]
M: Vec<RistrettoPoint>,
#[derivative(Debug = "ignore")]
precomputation: Arc<VartimeRistrettoPrecomputation>,
hash: Vec<u8>,
}

impl Eq for InputSet {}

impl PartialEq for InputSet {
fn eq(&self, other: &Self) -> bool {
// This only checks hashes for efficiency, which is fine given the constructor
self.hash == other.hash
}
}

impl InputSet {
// Version identifier used for hashing
const VERSION: u64 = 0;
Expand All @@ -37,6 +55,7 @@ impl InputSet {

Self {
M: M.to_vec(),
precomputation: Arc::new(VartimeRistrettoPrecomputation::new(M)),
hash: hasher.finalize().as_bytes().to_vec(),
}
}
Expand Down Expand Up @@ -68,6 +87,11 @@ impl InputSet {
&self.M
}

/// Get the precomputation for this [`InputSet`].
pub(crate) fn get_precomputation(&self) -> &VartimeRistrettoPrecomputation {
&self.precomputation
}

/// Get a cryptographic hash representation of this [`InputSet`], suitable for transcripting.
pub(crate) fn get_hash(&self) -> &[u8] {
&self.hash
Expand Down

0 comments on commit 81600b9

Please sign in to comment.