From 439b243338b036e037efa63398e3d79dc59040fb Mon Sep 17 00:00:00 2001 From: moana Date: Thu, 8 Feb 2024 13:53:13 +0100 Subject: [PATCH] Adapt to latest changes in safe --- Cargo.toml | 1 + src/hades.rs | 15 ++-- src/hash.rs | 90 +++++++++++------------- src/lib.rs | 2 +- src/{sponge.rs => state.rs} | 38 +++++------ tests/hash.rs | 133 +++++++++++++++++++++++------------- 6 files changed, 151 insertions(+), 128 deletions(-) rename src/{sponge.rs => state.rs} (66%) diff --git a/Cargo.toml b/Cargo.toml index 7c6da54..b3bda36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ bytecheck = { version = "0.6", optional = true, default-features = false } criterion = "0.3" rand = { version = "0.8", default-features = false, features = ["getrandom", "std_rng"] } ff = { version = "0.13", default-features = false } +once_cell = "1" [features] zk = [ diff --git a/src/hades.rs b/src/hades.rs index eab8200..4790441 100644 --- a/src/hades.rs +++ b/src/hades.rs @@ -62,7 +62,7 @@ mod tests { use dusk_bls12_381::BlsScalar; use dusk_bytes::ParseHexStr; - use safe::{DomainSeparator, IOCall, Permutation, Sponge}; + use safe::{IOCall, Permutation, Sponge}; use crate::hades::{permute, WIDTH}; @@ -70,6 +70,8 @@ mod tests { struct State([BlsScalar; WIDTH]); impl Permutation for State { + const ZERO_VALUE: BlsScalar = BlsScalar::zero(); + fn state_mut(&mut self) -> &mut [BlsScalar; WIDTH] { &mut self.0 } @@ -86,11 +88,7 @@ mod tests { BlsScalar::zero() } - fn zero_value() -> BlsScalar { - BlsScalar::zero() - } - - fn add(&mut self, right: BlsScalar, left: BlsScalar) -> BlsScalar { + fn add(&mut self, right: &BlsScalar, left: &BlsScalar) -> BlsScalar { right + left } } @@ -122,8 +120,9 @@ mod tests { let state = State::new([BlsScalar::zero(); WIDTH]); - let mut sponge = - Sponge::start(state, iopattern, DomainSeparator::from(0)); + let domain_sep = 0; + let mut sponge = Sponge::start(state, iopattern, domain_sep) + .expect("IO pattern should be valid"); // absorb given input sponge .absorb(input.len(), input) diff --git a/src/hash.rs b/src/hash.rs index 1efd37e..e301341 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -7,10 +7,10 @@ use alloc::vec::Vec; use dusk_bls12_381::BlsScalar; use dusk_jubjub::JubJubScalar; -use safe::{DomainSeparator, IOCall, Sponge}; +use safe::{IOCall, Sponge}; use crate::hades::WIDTH; -use crate::sponge::HadesPermutation; +use crate::state::HadesState; /// The Domain Separation for Poseidon #[derive(Debug, Clone, Copy, PartialEq)] @@ -25,23 +25,23 @@ pub enum Domain { Other, } -impl From<&Domain> for DomainSeparator { - // Encryption for the DomainSeparator are taken from section 4.2 of the - // paper and adapted to u64. - // When no domain is selected, we use the specification for variable length - // hashing and set it in `Hash::finalize()` - fn from(domain: &Domain) -> DomainSeparator { - let encoding = match domain { +impl Domain { + /// Encryption for the domain-separator are taken from section 4.2 of the + /// paper adapted to u64. + /// When `Other` is selected we set the domain-separator to zero. We can do + /// this since the io-pattern will be encoded in the tag in any case, + /// ensuring safety from collision attacks. + pub const fn encoding(&self) -> u64 { + match self { // 2^4 - 1 Domain::Merkle4 => 0x0000_0000_0000_000f, // 2^2 - 1 Domain::Merkle2 => 0x0000_0000_0000_0003, - // 2^8 + // 2^32 Domain::Encryption => 0x0000_0001_0000_0000, // 0 Domain::Other => 0x0000_0000_0000_0000, - }; - DomainSeparator::from(encoding) + } } } @@ -49,7 +49,7 @@ fn io_pattern( domain: Domain, input: &Vec<&[T]>, output_len: usize, -) -> Result<(usize, Vec), safe::Error> { +) -> Result, safe::Error> { let mut io_pattern = Vec::new(); // check total input length against domain let input_len = input.iter().fold(0, |acc, input| acc + input.len()); @@ -71,30 +71,17 @@ fn io_pattern( } io_pattern.push(IOCall::Squeeze(output_len as u32)); - Ok((input_len, io_pattern)) -} - -fn domain_separator( - domain: Domain, - input_len: usize, - output_len: usize, -) -> DomainSeparator { - match domain { - // when the domain separator is not set, we calculate it from - // the input and output length: - Domain::Other => { - // input_len * 2^8 + output_len - 1 - let mut encoding = (input_len as u64) << 8; - encoding += output_len as u64 - 1; - DomainSeparator::from(encoding) - } - // in all other cases we use the encoding defined by the From trait - _ => DomainSeparator::from(&domain), - } + Ok(io_pattern) } -/// Hash struct. +/// Hash any given input into one or several scalar using the Hades +/// permutation strategy. The Hash can absorb multiple chunks of input but will +/// only call `squeeze` once at the finalization of the hash. +/// The output length is set to 1 element per default, but this can be +/// overridden with [`Hash::output_len`]. pub struct Hash<'a> { + // once the #[feature(adt_const_params)] becomes stable, we can turn the + // Domain into a const generic domain: Domain, input: Vec<&'a [BlsScalar]>, output_len: usize, @@ -123,20 +110,18 @@ impl<'a> Hash<'a> { /// Finalize the hash. pub fn finalize(&self) -> Result, safe::Error> { // generate the io-pattern - let (input_len, io_pattern) = - io_pattern(self.domain, &self.input, self.output_len)?; + let io_pattern = io_pattern(self.domain, &self.input, self.output_len)?; - // get the domain-separator - let domain_sep = - domain_separator(self.domain, input_len, self.output_len); + // set the domain-separator + let domain_sep = self.domain.encoding(); // Generate the hash using the sponge framework. // initialize the sponge let mut sponge = Sponge::start( - HadesPermutation::new([BlsScalar::zero(); WIDTH]), + HadesState::new([BlsScalar::zero(); WIDTH]), io_pattern, domain_sep, - ); + )?; // absorb the input for input in self.input.iter() { sponge.absorb(input.len(), input)?; @@ -150,21 +135,23 @@ impl<'a> Hash<'a> { .expect("The io-pattern should not be violated")) } - /// Finalize the hash and output JubJubScalar. + /// Finalize the hash and output JubJubScalar by truncating the `BlsScalar` + /// output to 250 bits. pub fn finalize_truncated(&self) -> Result, safe::Error> { // finalize the hash as bls-scalar let bls_output = self.finalize()?; - // output the result as jubjub-scalar by truncating the two MSBs - let bls_mask = BlsScalar::from_raw([ + // 'cast' the bls-scalar result to a jubjub-scalar by truncating the 6 + // highest bits + let bit_mask = BlsScalar::from_raw([ 0xffff_ffff_ffff_ffff, 0xffff_ffff_ffff_ffff, 0xffff_ffff_ffff_ffff, - 0x3fff_ffff_ffff_ffff, + 0x03ff_ffff_ffff_ffff, ]); Ok(bls_output .iter() - .map(|bls| JubJubScalar::from_raw((bls & &bls_mask).reduce().0)) + .map(|bls| JubJubScalar::from_raw((bls & &bit_mask).reduce().0)) .collect()) } @@ -192,7 +179,7 @@ impl<'a> Hash<'a> { #[cfg(feature = "zk")] pub(crate) mod zk { use super::*; - use crate::sponge::zk::HadesPermutationGadget; + use crate::state::zk::HadesStateGadget; use dusk_plonk::prelude::*; @@ -229,20 +216,19 @@ pub(crate) mod zk { composer: &mut Composer, ) -> Result, safe::Error> { // generate the io-pattern - let (input_len, io_pattern) = + let io_pattern = io_pattern(self.domain, &self.input, self.output_len)?; // get the domain-separator - let domain_sep = - domain_separator(self.domain, input_len, self.output_len); + let domain_sep = self.domain.encoding(); // Generate the hash using the sponge framework. // initialize the sponge let mut sponge = Sponge::start( - HadesPermutationGadget::new(composer, [Composer::ZERO; WIDTH]), + HadesStateGadget::new(composer, [Composer::ZERO; WIDTH]), io_pattern, domain_sep, - ); + )?; // absorb the input for input in self.input.iter() { sponge.absorb(input.len(), input)?; diff --git a/src/lib.rs b/src/lib.rs index 1edcdb6..f5fc7d5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ extern crate alloc; mod cipher; mod hades; mod hash; -mod sponge; +mod state; #[cfg(feature = "cipher")] pub use cipher::{ diff --git a/src/sponge.rs b/src/state.rs similarity index 66% rename from src/sponge.rs rename to src/state.rs index ed8fd59..1388bc9 100644 --- a/src/sponge.rs +++ b/src/state.rs @@ -7,35 +7,33 @@ use dusk_bls12_381::BlsScalar; use safe::Permutation; -use crate::hades::{permute, WIDTH}; +use crate::hades::{self, WIDTH}; -pub(crate) struct HadesPermutation { +pub(crate) struct HadesState { state: [BlsScalar; WIDTH], } -impl Permutation for HadesPermutation { +impl Permutation for HadesState { + const ZERO_VALUE: BlsScalar = BlsScalar::zero(); + fn state_mut(&mut self) -> &mut [BlsScalar; WIDTH] { &mut self.state } fn permute(&mut self) { - permute(self.state_mut()) + hades::permute(self.state_mut()) } fn tag(&mut self, input: &[u8]) -> BlsScalar { BlsScalar::hash_to_scalar(input) } - fn zero_value() -> BlsScalar { - BlsScalar::zero() - } - - fn add(&mut self, right: BlsScalar, left: BlsScalar) -> BlsScalar { + fn add(&mut self, right: &BlsScalar, left: &BlsScalar) -> BlsScalar { right + left } } -impl HadesPermutation { +impl HadesState { pub fn new(state: [BlsScalar; WIDTH]) -> Self { Self { state } } @@ -47,14 +45,16 @@ pub(crate) mod zk { use dusk_plonk::prelude::*; use safe::Permutation; - use crate::hades::{permute_gadget, WIDTH}; + use crate::hades::{self, WIDTH}; - pub(crate) struct HadesPermutationGadget<'a> { + pub(crate) struct HadesStateGadget<'a> { composer: &'a mut Composer, state: [Witness; WIDTH], } - impl<'a> Permutation for HadesPermutationGadget<'a> { + impl<'a> Permutation for HadesStateGadget<'a> { + const ZERO_VALUE: Witness = Composer::ZERO; + fn state_mut(&mut self) -> &mut [Witness; WIDTH] { &mut self.state } @@ -62,7 +62,7 @@ pub(crate) mod zk { fn permute(&mut self) { let mut state = [Composer::ZERO; WIDTH]; state.copy_from_slice(&self.state); - permute_gadget(self.composer, &mut state); + hades::permute_gadget(self.composer, &mut state); self.state.copy_from_slice(&state); } @@ -71,18 +71,14 @@ pub(crate) mod zk { self.composer.append_witness(tag) } - fn zero_value() -> Witness { - Composer::ZERO - } - - fn add(&mut self, right: Witness, left: Witness) -> Witness { + fn add(&mut self, right: &Witness, left: &Witness) -> Witness { let constraint = - Constraint::new().left(1).a(left).right(1).b(right); + Constraint::new().left(1).a(*left).right(1).b(*right); self.composer.gate_add(constraint) } } - impl<'a> HadesPermutationGadget<'a> { + impl<'a> HadesStateGadget<'a> { pub fn new( composer: &'a mut Composer, state: [Witness; WIDTH], diff --git a/tests/hash.rs b/tests/hash.rs index dc0e0f4..3fbb61a 100644 --- a/tests/hash.rs +++ b/tests/hash.rs @@ -6,8 +6,7 @@ #![cfg(feature = "zk")] -use std::vec; - +use once_cell::sync::Lazy; use rand::rngs::StdRng; use rand::SeedableRng; @@ -16,18 +15,50 @@ use dusk_plonk::prelude::*; use dusk_poseidon::{Domain, Hash, HashGadget}; use ff::Field; -const CAPACITY: usize = 12; +static PUB_PARAMS: Lazy = Lazy::new(|| { + let mut rng = StdRng::seed_from_u64(0xbeef); + + const CAPACITY: usize = 12; + PublicParameters::setup(1 << CAPACITY, &mut rng) + .expect("Cannot initialize Public Parameters") +}); + +fn compile_and_verify(rng: &mut StdRng, circuit: C) -> Result<(), Error> +where + C: Circuit, +{ + let label = b"hash-gadget-tester"; + let (prover, verifier) = Compiler::compile::(&PUB_PARAMS, label)?; + + let (proof, _public_inputs) = prover.prove(rng, &circuit)?; + + let public_inputs = Vec::new(); + verifier.verify(&proof, &public_inputs) +} + +// ---------------- +// Test normal hash +// ---------------- -#[derive(Default, Debug)] -pub struct TestCircuit { - input: Vec, +#[derive(Debug)] +struct TestCircuit { + input: [BlsScalar; L], output: BlsScalar, } -impl TestCircuit { - pub fn random(rng: &mut StdRng, input_len: usize) -> Self { +impl Default for TestCircuit { + fn default() -> Self { + Self { + input: [BlsScalar::zero(); L], + output: BlsScalar::zero(), + } + } +} + +impl TestCircuit { + pub fn random(rng: &mut StdRng) -> Self { // create random input - let mut input = vec![BlsScalar::zero(); input_len]; + let mut input = [BlsScalar::zero(); L]; input .iter_mut() .for_each(|s| *s = BlsScalar::random(&mut *rng)); @@ -43,10 +74,10 @@ impl TestCircuit { } } -impl Circuit for TestCircuit { +impl Circuit for TestCircuit { fn circuit(&self, composer: &mut Composer) -> Result<(), PlonkError> { // append input to the circuit - let mut input_witnesses = vec![Composer::ZERO; self.input.len()]; + let mut input_witnesses = [Composer::ZERO; L]; self.input .iter() .zip(input_witnesses.iter_mut()) @@ -68,40 +99,50 @@ impl Circuit for TestCircuit { #[test] fn test_gadget() -> Result<(), Error> { - let label = b"hash-gadget-tester"; let mut rng = StdRng::seed_from_u64(0xbeef); - let pp = PublicParameters::setup(1 << CAPACITY, &mut rng)?; - for input_len in [3, 5, 15] { - let circuit = TestCircuit::random(&mut rng, input_len); + // test for input of 3 scalar + let circuit = TestCircuit::<3>::random(&mut rng); + compile_and_verify(&mut rng, circuit)?; - let (prover, verifier) = - Compiler::compile_with_circuit(&pp, label, &circuit)?; + // test for input of 5 scalar + let circuit = TestCircuit::<5>::random(&mut rng); + compile_and_verify(&mut rng, circuit)?; - let (proof, public_inputs) = prover.prove(&mut rng, &circuit)?; + // test for input of 15 scalar + let circuit = TestCircuit::<15>::random(&mut rng); + compile_and_verify(&mut rng, circuit) +} - verifier.verify(&proof, &public_inputs)?; - } +// ------------------- +// Test truncated hash +// ------------------- - Ok(()) +#[derive(Debug)] +struct TestTruncatedCircuit { + input: [BlsScalar; L], + output: JubJubScalar, } -#[derive(Default, Debug)] -pub struct TestTruncatedCircuit { - input: Vec, - output: BlsScalar, +impl Default for TestTruncatedCircuit { + fn default() -> Self { + Self { + input: [BlsScalar::zero(); L], + output: JubJubScalar::zero(), + } + } } -impl TestTruncatedCircuit { - pub fn random(rng: &mut StdRng, input_len: usize) -> Self { +impl TestTruncatedCircuit { + pub fn random(rng: &mut StdRng) -> Self { // create random input - let mut input = vec![BlsScalar::zero(); input_len]; + let mut input = [BlsScalar::zero(); L]; input .iter_mut() .for_each(|s| *s = BlsScalar::random(&mut *rng)); // calculate expected hash output - let output = Hash::digest(Domain::Other, &input) + let output = Hash::digest_truncated(Domain::Other, &input) .expect("hash creation should not fail"); Self { @@ -111,10 +152,10 @@ impl TestTruncatedCircuit { } } -impl Circuit for TestTruncatedCircuit { +impl Circuit for TestTruncatedCircuit { fn circuit(&self, composer: &mut Composer) -> Result<(), PlonkError> { // append input to the circuit - let mut input_witnesses = vec![Composer::ZERO; self.input.len()]; + let mut input_witnesses = [Composer::ZERO; L]; self.input .iter() .zip(input_witnesses.iter_mut()) @@ -127,9 +168,12 @@ impl Circuit for TestTruncatedCircuit { // check that the gadget result is as expected let mut hash = HashGadget::new(Domain::Other); hash.update(&input_witnesses); - let gadget_output = - HashGadget::digest(Domain::Other, composer, &input_witnesses) - .expect("hash creation should not fail"); + let gadget_output = HashGadget::digest_truncated( + Domain::Other, + composer, + &input_witnesses, + ) + .expect("hash creation should not fail"); composer.assert_equal(gadget_output[0], expected_output); Ok(()) @@ -138,20 +182,17 @@ impl Circuit for TestTruncatedCircuit { #[test] fn test_truncated_gadget() -> Result<(), Error> { - let label = b"truncated-tester"; let mut rng = StdRng::seed_from_u64(0xbeef); - let pp = PublicParameters::setup(1 << CAPACITY, &mut rng)?; - for input_len in [3, 5, 15] { - let circuit = TestTruncatedCircuit::random(&mut rng, input_len); + // test for input of 3 scalar + let circuit = TestTruncatedCircuit::<3>::random(&mut rng); + compile_and_verify(&mut rng, circuit)?; - let (prover, verifier) = - Compiler::compile_with_circuit(&pp, label, &circuit)?; - - let (proof, public_inputs) = prover.prove(&mut rng, &circuit)?; - - verifier.verify(&proof, &public_inputs)?; - } + // test for input of 5 scalar + let circuit = TestTruncatedCircuit::<5>::random(&mut rng); + compile_and_verify(&mut rng, circuit)?; - Ok(()) + // test for input of 15 scalar + let circuit = TestTruncatedCircuit::<15>::random(&mut rng); + compile_and_verify(&mut rng, circuit) }