Skip to content

Commit

Permalink
Adapt to latest changes in safe
Browse files Browse the repository at this point in the history
  • Loading branch information
moCello committed Feb 16, 2024
1 parent 0042cb3 commit 439b243
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 128 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ bytecheck = { version = "0.6", optional = true, default-features = false }
criterion = "0.3"
rand = { version = "0.8", default-features = false, features = ["getrandom", "std_rng"] }
ff = { version = "0.13", default-features = false }
once_cell = "1"

[features]
zk = [
Expand Down
15 changes: 7 additions & 8 deletions src/hades.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,16 @@ mod tests {

use dusk_bls12_381::BlsScalar;
use dusk_bytes::ParseHexStr;
use safe::{DomainSeparator, IOCall, Permutation, Sponge};
use safe::{IOCall, Permutation, Sponge};

use crate::hades::{permute, WIDTH};

#[derive(Default, Debug, Clone, Copy, PartialEq)]
struct State([BlsScalar; WIDTH]);

impl Permutation<BlsScalar, WIDTH> for State {
const ZERO_VALUE: BlsScalar = BlsScalar::zero();

fn state_mut(&mut self) -> &mut [BlsScalar; WIDTH] {
&mut self.0
}
Expand All @@ -86,11 +88,7 @@ mod tests {
BlsScalar::zero()
}

fn zero_value() -> BlsScalar {
BlsScalar::zero()
}

fn add(&mut self, right: BlsScalar, left: BlsScalar) -> BlsScalar {
fn add(&mut self, right: &BlsScalar, left: &BlsScalar) -> BlsScalar {
right + left
}
}
Expand Down Expand Up @@ -122,8 +120,9 @@ mod tests {

let state = State::new([BlsScalar::zero(); WIDTH]);

let mut sponge =
Sponge::start(state, iopattern, DomainSeparator::from(0));
let domain_sep = 0;
let mut sponge = Sponge::start(state, iopattern, domain_sep)
.expect("IO pattern should be valid");
// absorb given input
sponge
.absorb(input.len(), input)
Expand Down
90 changes: 38 additions & 52 deletions src/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
use alloc::vec::Vec;
use dusk_bls12_381::BlsScalar;
use dusk_jubjub::JubJubScalar;
use safe::{DomainSeparator, IOCall, Sponge};
use safe::{IOCall, Sponge};

use crate::hades::WIDTH;
use crate::sponge::HadesPermutation;
use crate::state::HadesState;

/// The Domain Separation for Poseidon
#[derive(Debug, Clone, Copy, PartialEq)]
Expand All @@ -25,31 +25,31 @@ pub enum Domain {
Other,
}

impl From<&Domain> for DomainSeparator {
// Encryption for the DomainSeparator are taken from section 4.2 of the
// paper and adapted to u64.
// When no domain is selected, we use the specification for variable length
// hashing and set it in `Hash::finalize()`
fn from(domain: &Domain) -> DomainSeparator {
let encoding = match domain {
impl Domain {
/// Encryption for the domain-separator are taken from section 4.2 of the
/// paper adapted to u64.
/// When `Other` is selected we set the domain-separator to zero. We can do
/// this since the io-pattern will be encoded in the tag in any case,
/// ensuring safety from collision attacks.
pub const fn encoding(&self) -> u64 {
match self {
// 2^4 - 1
Domain::Merkle4 => 0x0000_0000_0000_000f,
// 2^2 - 1
Domain::Merkle2 => 0x0000_0000_0000_0003,
// 2^8
// 2^32
Domain::Encryption => 0x0000_0001_0000_0000,
// 0
Domain::Other => 0x0000_0000_0000_0000,
};
DomainSeparator::from(encoding)
}
}
}

fn io_pattern<T>(
domain: Domain,
input: &Vec<&[T]>,
output_len: usize,
) -> Result<(usize, Vec<IOCall>), safe::Error> {
) -> Result<Vec<IOCall>, safe::Error> {
let mut io_pattern = Vec::new();
// check total input length against domain
let input_len = input.iter().fold(0, |acc, input| acc + input.len());
Expand All @@ -71,30 +71,17 @@ fn io_pattern<T>(
}
io_pattern.push(IOCall::Squeeze(output_len as u32));

Ok((input_len, io_pattern))
}

fn domain_separator(
domain: Domain,
input_len: usize,
output_len: usize,
) -> DomainSeparator {
match domain {
// when the domain separator is not set, we calculate it from
// the input and output length:
Domain::Other => {
// input_len * 2^8 + output_len - 1
let mut encoding = (input_len as u64) << 8;
encoding += output_len as u64 - 1;
DomainSeparator::from(encoding)
}
// in all other cases we use the encoding defined by the From trait
_ => DomainSeparator::from(&domain),
}
Ok(io_pattern)
}

/// Hash struct.
/// Hash any given input into one or several scalar using the Hades
/// permutation strategy. The Hash can absorb multiple chunks of input but will
/// only call `squeeze` once at the finalization of the hash.
/// The output length is set to 1 element per default, but this can be
/// overridden with [`Hash::output_len`].
pub struct Hash<'a> {
// once the #[feature(adt_const_params)] becomes stable, we can turn the
// Domain into a const generic
domain: Domain,
input: Vec<&'a [BlsScalar]>,
output_len: usize,
Expand Down Expand Up @@ -123,20 +110,18 @@ impl<'a> Hash<'a> {
/// Finalize the hash.
pub fn finalize(&self) -> Result<Vec<BlsScalar>, safe::Error> {
// generate the io-pattern
let (input_len, io_pattern) =
io_pattern(self.domain, &self.input, self.output_len)?;
let io_pattern = io_pattern(self.domain, &self.input, self.output_len)?;

// get the domain-separator
let domain_sep =
domain_separator(self.domain, input_len, self.output_len);
// set the domain-separator
let domain_sep = self.domain.encoding();

// Generate the hash using the sponge framework.
// initialize the sponge
let mut sponge = Sponge::start(
HadesPermutation::new([BlsScalar::zero(); WIDTH]),
HadesState::new([BlsScalar::zero(); WIDTH]),
io_pattern,
domain_sep,
);
)?;
// absorb the input
for input in self.input.iter() {
sponge.absorb(input.len(), input)?;
Expand All @@ -150,21 +135,23 @@ impl<'a> Hash<'a> {
.expect("The io-pattern should not be violated"))
}

/// Finalize the hash and output JubJubScalar.
/// Finalize the hash and output JubJubScalar by truncating the `BlsScalar`
/// output to 250 bits.
pub fn finalize_truncated(&self) -> Result<Vec<JubJubScalar>, safe::Error> {
// finalize the hash as bls-scalar
let bls_output = self.finalize()?;

// output the result as jubjub-scalar by truncating the two MSBs
let bls_mask = BlsScalar::from_raw([
// 'cast' the bls-scalar result to a jubjub-scalar by truncating the 6
// highest bits
let bit_mask = BlsScalar::from_raw([
0xffff_ffff_ffff_ffff,
0xffff_ffff_ffff_ffff,
0xffff_ffff_ffff_ffff,
0x3fff_ffff_ffff_ffff,
0x03ff_ffff_ffff_ffff,
]);
Ok(bls_output
.iter()
.map(|bls| JubJubScalar::from_raw((bls & &bls_mask).reduce().0))
.map(|bls| JubJubScalar::from_raw((bls & &bit_mask).reduce().0))
.collect())
}

Expand Down Expand Up @@ -192,7 +179,7 @@ impl<'a> Hash<'a> {
#[cfg(feature = "zk")]
pub(crate) mod zk {
use super::*;
use crate::sponge::zk::HadesPermutationGadget;
use crate::state::zk::HadesStateGadget;

use dusk_plonk::prelude::*;

Expand Down Expand Up @@ -229,20 +216,19 @@ pub(crate) mod zk {
composer: &mut Composer,
) -> Result<Vec<Witness>, safe::Error> {
// generate the io-pattern
let (input_len, io_pattern) =
let io_pattern =
io_pattern(self.domain, &self.input, self.output_len)?;

// get the domain-separator
let domain_sep =
domain_separator(self.domain, input_len, self.output_len);
let domain_sep = self.domain.encoding();

// Generate the hash using the sponge framework.
// initialize the sponge
let mut sponge = Sponge::start(
HadesPermutationGadget::new(composer, [Composer::ZERO; WIDTH]),
HadesStateGadget::new(composer, [Composer::ZERO; WIDTH]),
io_pattern,
domain_sep,
);
)?;
// absorb the input
for input in self.input.iter() {
sponge.absorb(input.len(), input)?;
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ extern crate alloc;
mod cipher;
mod hades;
mod hash;
mod sponge;
mod state;

#[cfg(feature = "cipher")]
pub use cipher::{
Expand Down
38 changes: 17 additions & 21 deletions src/sponge.rs → src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,33 @@
use dusk_bls12_381::BlsScalar;
use safe::Permutation;

use crate::hades::{permute, WIDTH};
use crate::hades::{self, WIDTH};

pub(crate) struct HadesPermutation {
pub(crate) struct HadesState {
state: [BlsScalar; WIDTH],
}

impl Permutation<BlsScalar, WIDTH> for HadesPermutation {
impl Permutation<BlsScalar, WIDTH> for HadesState {
const ZERO_VALUE: BlsScalar = BlsScalar::zero();

fn state_mut(&mut self) -> &mut [BlsScalar; WIDTH] {
&mut self.state
}

fn permute(&mut self) {
permute(self.state_mut())
hades::permute(self.state_mut())
}

fn tag(&mut self, input: &[u8]) -> BlsScalar {
BlsScalar::hash_to_scalar(input)
}

fn zero_value() -> BlsScalar {
BlsScalar::zero()
}

fn add(&mut self, right: BlsScalar, left: BlsScalar) -> BlsScalar {
fn add(&mut self, right: &BlsScalar, left: &BlsScalar) -> BlsScalar {
right + left
}
}

impl HadesPermutation {
impl HadesState {
pub fn new(state: [BlsScalar; WIDTH]) -> Self {
Self { state }
}
Expand All @@ -47,22 +45,24 @@ pub(crate) mod zk {
use dusk_plonk::prelude::*;
use safe::Permutation;

use crate::hades::{permute_gadget, WIDTH};
use crate::hades::{self, WIDTH};

pub(crate) struct HadesPermutationGadget<'a> {
pub(crate) struct HadesStateGadget<'a> {
composer: &'a mut Composer,
state: [Witness; WIDTH],
}

impl<'a> Permutation<Witness, WIDTH> for HadesPermutationGadget<'a> {
impl<'a> Permutation<Witness, WIDTH> for HadesStateGadget<'a> {
const ZERO_VALUE: Witness = Composer::ZERO;

fn state_mut(&mut self) -> &mut [Witness; WIDTH] {
&mut self.state
}

fn permute(&mut self) {
let mut state = [Composer::ZERO; WIDTH];
state.copy_from_slice(&self.state);
permute_gadget(self.composer, &mut state);
hades::permute_gadget(self.composer, &mut state);
self.state.copy_from_slice(&state);
}

Expand All @@ -71,18 +71,14 @@ pub(crate) mod zk {
self.composer.append_witness(tag)
}

fn zero_value() -> Witness {
Composer::ZERO
}

fn add(&mut self, right: Witness, left: Witness) -> Witness {
fn add(&mut self, right: &Witness, left: &Witness) -> Witness {
let constraint =
Constraint::new().left(1).a(left).right(1).b(right);
Constraint::new().left(1).a(*left).right(1).b(*right);
self.composer.gate_add(constraint)
}
}

impl<'a> HadesPermutationGadget<'a> {
impl<'a> HadesStateGadget<'a> {
pub fn new(
composer: &'a mut Composer,
state: [Witness; WIDTH],
Expand Down
Loading

0 comments on commit 439b243

Please sign in to comment.