diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 40df3e2b..a9f59a2e 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -49,7 +49,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: nightly-2023-04-24 + toolchain: nightly-2024-02-14 components: clippy override: true - name: Run Clippy diff --git a/Cargo.toml b/Cargo.toml index c002d68b..b5622a6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,24 +7,28 @@ authors = ["Leo Lara "] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[patch.crates-io] +halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v0.3.0" } + +[patch."https://github.com/scroll-tech/halo2.git"] +halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v0.3.0" } + + [dependencies] pyo3 = { version = "0.19.1", features = ["extension-module"] } halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", features = [ "circuit-params", -], tag = "v2023_04_20" } -halo2curves = { git = 'https://github.com/privacy-scaling-explorations/halo2curves', tag = "0.3.2", features = [ "derive_serde", -] } -polyexen = { git = "https://github.com/Dhole/polyexen.git", rev = "4d128ad2ebd0094160ea77e30fb9ce56abb854e0" } +], tag = "v0.3.0" } + +polyexen = { git = "https://github.com/Dhole/polyexen.git", rev = "16a85c5411f804dc49bbf373d24ff9eedadedfbe" } num-bigint = { version = "0.4", features = ["rand"] } uuid = { version = "1.4.0", features = ["v1", "rng"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +hyperplonk_benchmark = { git = "https://github.com/qwang98/plonkish.git", branch = "main", package = "benchmark" } +plonkish_backend = { git = "https://github.com/qwang98/plonkish.git", branch = "main", package = "plonkish_backend" } +regex = "1" [dev-dependencies] rand_chacha = "0.3" - -[patch."https://github.com/privacy-scaling-explorations/halo2.git"] -halo2_proofs = { git = "https://github.com/appliedzkp/halo2.git", rev = "d3746109d7d38be53afc8ddae8fdfaf1f02ad1d7", features = [ - "circuit-params", -] } diff --git a/examples/blake2f.rs b/examples/blake2f.rs new file mode 100644 index 00000000..1ce2db7f --- /dev/null +++ b/examples/blake2f.rs @@ -0,0 +1,1499 @@ +use chiquito::{ + frontend::dsl::{ + cb::{eq, select, table}, + lb::LookupTable, + super_circuit, CircuitContext, StepTypeSetupContext, StepTypeWGHandler, + }, + plonkish::{ + backend::halo2::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, + compiler::{ + cell_manager::{MaxWidthCellManager, SingleRowCellManager}, + config, + step_selector::SimpleStepSelectorBuilder, + }, + ir::sc::SuperCircuit, + }, + poly::ToExpr, + sbpir::query::Queriable, +}; +use halo2_proofs::{ + dev::MockProver, + halo2curves::{bn256::Fr, group::ff::PrimeField}, +}; +use std::{fmt::Write, hash::Hash}; + +pub const IV_LEN: usize = 8; +pub const SIGMA_VECTOR_LENGTH: usize = 16; +pub const SIGMA_VECTOR_NUMBER: usize = 10; +pub const R1: u64 = 32; +pub const R2: u64 = 24; +pub const R3: u64 = 16; +pub const R4: u64 = 63; +pub const MIXING_ROUNDS: u64 = 12; +pub const SPLIT_64BITS: u64 = 16; +pub const BASE_4BITS: u64 = 16; +pub const XOR_4SPLIT_64BITS: u64 = SPLIT_64BITS * SPLIT_64BITS; +pub const V_LEN: usize = 16; +pub const M_LEN: usize = 16; +pub const H_LEN: usize = 8; +pub const G_ROUNDS: u64 = 16; + +pub const IV_VALUES: [u64; IV_LEN] = [ + 0x6A09E667F3BCC908, + 0xBB67AE8584CAA73B, + 0x3C6EF372FE94F82B, + 0xA54FF53A5F1D36F1, + 0x510E527FADE682D1, + 0x9B05688C2B3E6C1F, + 0x1F83D9ABFB41BD6B, + 0x5BE0CD19137E2179, +]; + +pub const SIGMA_VALUES: [[usize; SIGMA_VECTOR_LENGTH]; SIGMA_VECTOR_NUMBER] = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3], + [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4], + [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8], + [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13], + [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9], + [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11], + [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10], + [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5], + [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0], +]; + +pub const XOR_VALUES: [u8; XOR_4SPLIT_64BITS as usize] = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, + 12, 15, 14, 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13, 3, 2, 1, 0, 7, 6, 5, 4, 11, + 10, 9, 8, 15, 14, 13, 12, 4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11, 5, 4, 7, 6, 1, + 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, 6, 7, 4, 5, 2, 3, 0, 1, 14, 15, 12, 13, 10, 11, 8, 9, 7, + 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, + 5, 6, 7, 9, 8, 11, 10, 13, 12, 15, 14, 1, 0, 3, 2, 5, 4, 7, 6, 10, 11, 8, 9, 14, 15, 12, 13, 2, + 3, 0, 1, 6, 7, 4, 5, 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4, 12, 13, 14, 15, 8, + 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3, 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 14, + 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, + 2, 1, 0, +]; + +pub fn string_to_u64(inputs: [&str; 4]) -> [u64; 4] { + inputs + .iter() + .map(|&input| { + assert_eq!(16, input.len()); + u64::from_le_bytes( + (0..input.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&input[i..i + 2], 16).unwrap()) + .collect::>() + .try_into() + .unwrap(), + ) + }) + .collect::>() + .try_into() + .unwrap() +} + +pub fn u64_to_string(inputs: &[u64; 4]) -> [String; 4] { + inputs + .iter() + .map(|input| { + let mut s = String::new(); + for byte in input.to_le_bytes() { + write!(&mut s, "{:02x}", byte).expect("Unable to write"); + } + s + }) + .collect::>() + .try_into() + .unwrap() +} + +pub fn split_to_4bits_values(vec_values: &[u64]) -> Vec> { + vec_values + .iter() + .map(|&value| { + let mut value = value; + (0..SPLIT_64BITS) + .map(|_| { + let v = value % BASE_4BITS; + value >>= 4; + F::from(v) + }) + .collect() + }) + .collect() +} + +fn blake2f_iv_table( + ctx: &mut CircuitContext, + _: usize, +) -> LookupTable { + let lookup_iv_row: Queriable = ctx.fixed("iv row"); + let lookup_iv_value: Queriable = ctx.fixed("iv value"); + + let iv_values = IV_VALUES; + ctx.pragma_num_steps(IV_LEN); + ctx.fixed_gen(move |ctx| { + for (i, &value) in iv_values.iter().enumerate() { + ctx.assign(i, lookup_iv_row, F::from(i as u64)); + ctx.assign(i, lookup_iv_value, F::from(value)); + } + }); + + ctx.new_table(table().add(lookup_iv_row).add(lookup_iv_value)) +} + +// For range checking +fn blake2f_4bits_table( + ctx: &mut CircuitContext, + _: usize, +) -> LookupTable { + let lookup_4bits_row: Queriable = ctx.fixed("4bits row"); + let lookup_4bits_value: Queriable = ctx.fixed("4bits value"); + + ctx.pragma_num_steps(SPLIT_64BITS as usize); + ctx.fixed_gen(move |ctx| { + for i in 0..SPLIT_64BITS as usize { + ctx.assign(i, lookup_4bits_row, F::ONE); + ctx.assign(i, lookup_4bits_value, F::from(i as u64)); + } + }); + + ctx.new_table(table().add(lookup_4bits_row).add(lookup_4bits_value)) +} + +fn blake2f_xor_4bits_table( + ctx: &mut CircuitContext, + _: usize, +) -> LookupTable { + let lookup_xor_row: Queriable = ctx.fixed("xor row"); + let lookup_xor_value: Queriable = ctx.fixed("xor value"); + + ctx.pragma_num_steps((SPLIT_64BITS * SPLIT_64BITS) as usize); + let xor_values = XOR_VALUES; + ctx.fixed_gen(move |ctx| { + for (i, &value) in xor_values.iter().enumerate() { + ctx.assign(i, lookup_xor_row, F::from(i as u64)); + ctx.assign(i, lookup_xor_value, F::from(value as u64)); + } + }); + + ctx.new_table(table().add(lookup_xor_row).add(lookup_xor_value)) +} + +#[derive(Clone, Copy)] +struct CircuitParams { + pub iv_table: LookupTable, + pub bits_table: LookupTable, + pub xor_4bits_table: LookupTable, +} + +impl CircuitParams { + fn check_4bit( + self, + ctx: &mut StepTypeSetupContext, + bits: Queriable, + ) { + ctx.add_lookup(self.bits_table.apply(1).apply(bits)); + } + + fn check_3bit( + self, + ctx: &mut StepTypeSetupContext, + bits: Queriable, + ) { + ctx.add_lookup(self.bits_table.apply(1).apply(bits)); + ctx.add_lookup(self.bits_table.apply(1).apply(bits * 2)); + } + + fn check_xor( + self, + ctx: &mut StepTypeSetupContext, + b1: Queriable, + b2: Queriable, + xor: Queriable, + ) { + ctx.add_lookup(self.xor_4bits_table.apply(b1 * BASE_4BITS + b2).apply(xor)); + } + + fn check_not( + self, + ctx: &mut StepTypeSetupContext, + b1: Queriable, + xor: Queriable, + ) { + ctx.add_lookup(self.xor_4bits_table.apply(b1 * BASE_4BITS + 0xF).apply(xor)); + } + + fn check_iv( + self, + ctx: &mut StepTypeSetupContext, + i: usize, + iv: Queriable, + ) { + ctx.add_lookup(self.iv_table.apply(i).apply(iv)); + } +} + +struct PreInput { + round: F, + t0: F, + t1: F, + f: F, + v_vec: Vec, + h_vec: Vec, + m_vec: Vec, + h_split_4bits_vec: Vec>, + m_split_4bits_vec: Vec>, + t_split_4bits_vec: Vec>, + iv_split_4bits_vec: Vec>, + final_split_bits_vec: Vec>, +} + +struct GInput { + round: F, + v_vec: Vec, + h_vec: Vec, + m_vec: Vec, + v_mid1_vec: Vec, + v_mid2_vec: Vec, + v_mid3_vec: Vec, + v_mid4_vec: Vec, + v_mid_va_bit_vec: Vec>, + v_mid_vb_bit_vec: Vec>, + v_mid_vc_bit_vec: Vec>, + v_mid_vd_bit_vec: Vec>, + v_xor_d_bit_vec: Vec>, + v_xor_b_bit_vec: Vec>, + b_bit_vec: Vec, + b_3bits_vec: Vec, +} + +struct FinalInput { + round: F, + v_vec: Vec, + h_vec: Vec, + output_vec: Vec, + v_split_bit_vec: Vec>, + h_split_bit_vec: Vec>, + v_xor_split_bit_vec: Vec>, + final_split_bit_vec: Vec>, +} + +struct InputValues { + pub round: u32, // 32bit + pub h_vec: [u64; H_LEN], // 8 * 64bits + pub m_vec: [u64; M_LEN], // 16 * 64bits + pub t0: u64, // 64bits + pub t1: u64, // 64bits + pub f: bool, // 8bits +} + +struct GStepParams { + m_vec: Vec>, + v_mid_va_bit_vec: Vec>, + v_mid_vb_bit_vec: Vec>, + v_mid_vc_bit_vec: Vec>, + v_mid_vd_bit_vec: Vec>, + v_xor_b_bit_vec: Vec>, + v_xor_d_bit_vec: Vec>, + input_vec: Vec>, + output_vec: Vec>, + b_bit: Queriable, + b_3bits: Queriable, +} + +fn split_value_4bits(mut value: u128, n: u64) -> Vec { + (0..n) + .map(|_| { + let v = value % BASE_4BITS as u128; + value /= BASE_4BITS as u128; + + F::from(v as u64) + }) + .collect() +} + +fn split_xor_value(value1: u64, value2: u64) -> Vec { + let mut value1 = value1; + let mut value2 = value2; + let bit_values: Vec = (0..64) + .map(|_| { + let b1 = value1 % 2; + value1 /= 2; + let b2 = value2 % 2; + value2 /= 2; + b1 ^ b2 + }) + .collect(); + (0..SPLIT_64BITS as usize) + .map(|i| { + F::from( + bit_values[i * 4] + + bit_values[i * 4 + 1] * 2 + + bit_values[i * 4 + 2] * 4 + + bit_values[i * 4 + 3] * 8, + ) + }) + .collect() +} + +fn g_wg( + (v1_vec_values, v2_vec_values): (&mut [u64], &mut [u64]), + (a, b, c, d): (usize, usize, usize, usize), + (x, y): (u64, u64), + (va_bit_vec, vb_bit_vec): (&mut Vec>, &mut Vec>), + (vc_bit_vec, vd_bit_vec): (&mut Vec>, &mut Vec>), + (v_xor_d_bit_vec, v_xor_b_bit_vec): (&mut Vec>, &mut Vec>), + (b_bit_vec, b_3bits_vec): (&mut Vec, &mut Vec), +) { + va_bit_vec.push(split_value_4bits( + v1_vec_values[a] as u128 + v1_vec_values[b] as u128 + x as u128, + SPLIT_64BITS + 1, + )); + v1_vec_values[a] = (v1_vec_values[a] as u128 + v1_vec_values[b] as u128 + x as u128) as u64; + + vd_bit_vec.push(split_value_4bits(v1_vec_values[d] as u128, SPLIT_64BITS)); + v1_vec_values[d] = ((v1_vec_values[d] ^ v1_vec_values[a]) >> R1) + ^ (v1_vec_values[d] ^ v1_vec_values[a]) << (64 - R1); + v_xor_d_bit_vec.push(split_value_4bits(v1_vec_values[d] as u128, SPLIT_64BITS)); + + vc_bit_vec.push(split_value_4bits( + v1_vec_values[c] as u128 + v1_vec_values[d] as u128, + SPLIT_64BITS + 1, + )); + v1_vec_values[c] = (v1_vec_values[c] as u128 + v1_vec_values[d] as u128) as u64; + + vb_bit_vec.push(split_value_4bits(v1_vec_values[b] as u128, SPLIT_64BITS)); + v1_vec_values[b] = ((v1_vec_values[b] ^ v1_vec_values[c]) >> R2) + ^ (v1_vec_values[b] ^ v1_vec_values[c]) << (64 - R2); + v_xor_b_bit_vec.push(split_value_4bits(v1_vec_values[b] as u128, SPLIT_64BITS)); + + va_bit_vec.push(split_value_4bits( + v1_vec_values[a] as u128 + v1_vec_values[b] as u128 + y as u128, + SPLIT_64BITS + 1, + )); + v2_vec_values[a] = (v1_vec_values[a] as u128 + v1_vec_values[b] as u128 + y as u128) as u64; + + vd_bit_vec.push(split_value_4bits(v1_vec_values[d] as u128, SPLIT_64BITS)); + v2_vec_values[d] = ((v1_vec_values[d] ^ v2_vec_values[a]) >> R3) + ^ (v1_vec_values[d] ^ v2_vec_values[a]) << (64 - R3); + v_xor_d_bit_vec.push(split_value_4bits(v2_vec_values[d] as u128, SPLIT_64BITS)); + + vc_bit_vec.push(split_value_4bits( + v1_vec_values[c] as u128 + v2_vec_values[d] as u128, + SPLIT_64BITS + 1, + )); + v2_vec_values[c] = (v1_vec_values[c] as u128 + v2_vec_values[d] as u128) as u64; + + vb_bit_vec.push(split_value_4bits(v1_vec_values[b] as u128, SPLIT_64BITS)); + v2_vec_values[b] = ((v1_vec_values[b] ^ v2_vec_values[c]) >> R4) + ^ (v1_vec_values[b] ^ v2_vec_values[c]) << (64 - R4); + v_xor_b_bit_vec.push(split_value_4bits( + (v1_vec_values[b] ^ v2_vec_values[c]) as u128, + SPLIT_64BITS, + )); + let bits = (v1_vec_values[b] ^ v2_vec_values[c]) / 2u64.pow(60); + b_bit_vec.push(F::from(bits / 8)); + b_3bits_vec.push(F::from(bits % 8)) +} + +fn split_4bit_signals( + ctx: &mut StepTypeSetupContext, + params: &CircuitParams, + input: &[Queriable], + output: &[Vec>], +) { + for (i, split_vec) in output.iter().enumerate() { + let mut sum_value = 0.expr() * 1; + + for &bits in split_vec.iter().rev() { + params.check_4bit(ctx, bits); + sum_value = sum_value * BASE_4BITS + bits; + } + ctx.constr(eq(sum_value, input[i])) + } +} + +// We check G function one time by calling twice g_setup function.c +// Because the G function can be divided into two similar parts. +fn g_setup( + ctx: &mut StepTypeSetupContext<'_, F>, + params: CircuitParams, + q_params: GStepParams, + (a, b, c, d): (usize, usize, usize, usize), + (move1, move2): (u64, u64), + s: usize, + flag: bool, +) { + let mut a_bits_sum_value = 0.expr() * 1; + let mut a_bits_sum_mod_value = 0.expr() * 1; + for (j, &bits) in q_params.v_mid_va_bit_vec.iter().rev().enumerate() { + a_bits_sum_value = a_bits_sum_value * BASE_4BITS + bits; + if j != 0 { + a_bits_sum_mod_value = a_bits_sum_mod_value * BASE_4BITS + bits; + } + params.check_4bit(ctx, bits); + } + // check v_mid_va_bit_vec = 4bit split of v[a] + v[b] + x + ctx.constr(eq( + a_bits_sum_value, + q_params.input_vec[a] + q_params.input_vec[b] + q_params.m_vec[s], + )); + // check v[a] = (v[a] + v[b] + x) mod 2^64 + ctx.constr(eq(a_bits_sum_mod_value, q_params.output_vec[a])); + + // check d_bits_sum_value = 4bit split of v[b] + let mut d_bits_sum_value = 0.expr() * 1; + for &bits in q_params.v_mid_vd_bit_vec.iter().rev() { + d_bits_sum_value = d_bits_sum_value * BASE_4BITS + bits; + params.check_4bit(ctx, bits); + } + ctx.constr(eq(d_bits_sum_value, q_params.input_vec[d])); + + let mut ad_xor_sum_value = 0.expr() * 1; + for &bits in q_params.v_xor_d_bit_vec.iter().rev() { + ad_xor_sum_value = ad_xor_sum_value * BASE_4BITS + bits; + } + // check v_xor_d_bit_vec = 4bit split of v[d] + ctx.constr(eq(ad_xor_sum_value, q_params.output_vec[d])); + // check v_xor_d_bit_vec[i] = (v[d][i] ^ v[a][i]) >>> R1(or R3) + for j in 0..SPLIT_64BITS as usize { + params.check_xor( + ctx, + q_params.v_mid_va_bit_vec[j], + q_params.v_mid_vd_bit_vec[j], + q_params.v_xor_d_bit_vec + [(j + BASE_4BITS as usize - move1 as usize) % BASE_4BITS as usize], + ); + } + + // check v[c] = (v[c] + v[d]) mod 2^64 + let mut c_bits_sum_value = 0.expr() * 1; + let mut c_bits_sum_mod_value = 0.expr() * 1; + for (j, &bits) in q_params.v_mid_vc_bit_vec.iter().rev().enumerate() { + c_bits_sum_value = c_bits_sum_value * BASE_4BITS + bits; + if j != 0 { + c_bits_sum_mod_value = c_bits_sum_mod_value * BASE_4BITS + bits; + } + params.check_4bit(ctx, bits); + } + // check v_mid_vc_bit_vec = 4bit split of (v[c] + v[d]) + ctx.constr(eq( + c_bits_sum_value, + q_params.input_vec[c] + q_params.output_vec[d], + )); + // check v[c] = (v[c] + v[d] ) mod 2^64 + ctx.constr(eq(c_bits_sum_mod_value, q_params.output_vec[c])); + + let mut b_bits_sum_value = 0.expr() * 1; + for &bits in q_params.v_mid_vb_bit_vec.iter().rev() { + b_bits_sum_value = b_bits_sum_value * BASE_4BITS + bits; + params.check_4bit(ctx, bits); + } + + // v_mid_vb_bit_vec = 4bit split of v[b] + ctx.constr(eq(b_bits_sum_value, q_params.input_vec[b])); + let mut bc_xor_sum_value = 0.expr() * 1; + for (j, &bits) in q_params.v_xor_b_bit_vec.iter().rev().enumerate() { + if j == 0 && flag { + // b_bit * 8 + b_3bits = v_xor_b_bit_vec[0] + bc_xor_sum_value = q_params.b_3bits * 1; + ctx.constr(eq(q_params.b_bit * 8 + q_params.b_3bits, bits)); + } else { + bc_xor_sum_value = bc_xor_sum_value * BASE_4BITS + bits; + } + params.check_4bit(ctx, bits); + } + if flag { + bc_xor_sum_value = bc_xor_sum_value * 2 + q_params.b_bit; + + ctx.constr(eq(q_params.b_bit * (q_params.b_bit - 1), 0)); + // To constraint b_3bits_vec[i/2] \in [0..8) + params.check_3bit(ctx, q_params.b_3bits); + } + // check v_xor_b_bit_vec = v[b] + ctx.constr(eq(bc_xor_sum_value, q_params.output_vec[b])); + + // check v_xor_b_bit_vec[i] = (v[b][i] ^ v[c][i]) >>> R2(or R4) + for j in 0..SPLIT_64BITS as usize { + params.check_xor( + ctx, + q_params.v_mid_vb_bit_vec[j], + q_params.v_mid_vc_bit_vec[j], + q_params.v_xor_b_bit_vec + [(j + BASE_4BITS as usize - move2 as usize) % BASE_4BITS as usize], + ); + } +} + +fn blake2f_circuit( + ctx: &mut CircuitContext, + params: CircuitParams, +) { + let v_vec: Vec> = (0..V_LEN) + .map(|i| ctx.forward(format!("v_vec[{}]", i).as_str())) + .collect(); + let h_vec: Vec> = (0..H_LEN) + .map(|i| ctx.forward(format!("h_vec[{}]", i).as_str())) + .collect(); + let m_vec: Vec> = (0..M_LEN) + .map(|i| ctx.forward(format!("m_vec[{}]", i).as_str())) + .collect(); + let round = ctx.forward("round"); + + let blake2f_pre_step = ctx.step_type_def("blake2f_pre_step", |ctx| { + let v_vec = v_vec.clone(); + let wg_v_vec = v_vec.clone(); + + let h_vec = h_vec.clone(); + let wg_h_vec = h_vec.clone(); + + let m_vec = m_vec.clone(); + let wg_m_vec = m_vec.clone(); + + let t0 = ctx.internal("t0"); + let t1 = ctx.internal("t1"); + let f = ctx.internal("f"); + + // h_split_4bits_vec = 4bit split of h_vec + let h_split_4bits_vec: Vec>> = (0..H_LEN) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("h_split_4bits_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_h_split_4bits_vec = h_split_4bits_vec.clone(); + + // m_split_4bits_vec = 4bit split of m_vec + let m_split_4bits_vec: Vec>> = (0..M_LEN) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("m_split_4bits_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_m_split_4bits_vec = m_split_4bits_vec.clone(); + + // t_split_4bits_vec = 4bit split of t0 and t1 + let t_split_4bits_vec: Vec>> = (0..2) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("t_split_4bits_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_t_split_4bits_vec = t_split_4bits_vec.clone(); + + // iv_split_4bits_vec = 4bit split of IV[5], IV[6], IV[7] + let iv_split_4bits_vec: Vec>> = (0..3) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("iv_split_4bits_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_iv_split_4bits_vec = iv_split_4bits_vec.clone(); + + // final_split_bits_vec = 4bit split of IV[5] xor t0, IV[6] xor t1, IV[7] xor + // 0xFFFFFFFFFFFFFFFF, + let final_split_bits_vec: Vec>> = (0..3) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("final_split_bits_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_final_split_bits_vec = final_split_bits_vec.clone(); + + ctx.setup(move |ctx| { + // check inputs: h_vec + split_4bit_signals(ctx, ¶ms, &h_vec, &h_split_4bits_vec); + + // check inputs: m_vec + split_4bit_signals(ctx, ¶ms, &m_vec, &m_split_4bits_vec); + + // check inputs: t0,t1 + split_4bit_signals(ctx, ¶ms, &[t0, t1], &t_split_4bits_vec); + + // check input f + ctx.constr(eq(f * (f - 1), 0)); + + // check v_vec + for i in 0..H_LEN { + ctx.constr(eq(v_vec[i], h_vec[i])); + } + for (i, &iv) in v_vec[V_LEN / 2..V_LEN].iter().enumerate() { + params.check_iv(ctx, i, iv); + } + + // check the split-fields of v[12], v[13], v[14] + split_4bit_signals(ctx, ¶ms, &v_vec[12..15], &iv_split_4bits_vec); + + // check v[12] := v[12] ^ (t mod 2**w) + // check v[13] := v[13] ^ (t >> w) + for (i, (final_plit_bits_value, (iv_split_bits_value, t_split_bits_value))) in + final_split_bits_vec + .iter() + .zip(iv_split_4bits_vec.iter().zip(t_split_4bits_vec.iter())) + .enumerate() + .take(2) + { + let mut final_bits_sum_value = 0.expr() * 1; + for (&value, (&iv, &t)) in final_plit_bits_value.iter().rev().zip( + iv_split_bits_value + .iter() + .rev() + .zip(t_split_bits_value.iter().rev()), + ) { + params.check_xor(ctx, iv, t, value); + final_bits_sum_value = final_bits_sum_value * BASE_4BITS + value; + } + ctx.constr(eq(final_bits_sum_value, v_vec[12 + i].next())) + } + + // check if f, v[14] = v[14] ^ 0xffffffffffffffff else v[14] + let mut final_bits_sum_value = 0.expr() * 1; + for (&bits, &iv) in final_split_bits_vec[2] + .iter() + .rev() + .zip(iv_split_4bits_vec[2].iter().rev()) + { + params.check_not(ctx, iv, bits); + final_bits_sum_value = final_bits_sum_value * BASE_4BITS + bits; + } + + // check v_vec v_vec.next + for &v in v_vec.iter().take(12) { + ctx.transition(eq(v, v.next())); + } + ctx.transition(eq( + select(f, final_bits_sum_value, v_vec[14]), + v_vec[14].next(), + )); + ctx.transition(eq(v_vec[15], v_vec[15].next())); + // check h_vec h_vec.next + for &h in h_vec.iter() { + ctx.transition(eq(h, h.next())); + } + // check m_vec m_vec.next + for &m in m_vec.iter() { + ctx.transition(eq(m, m.next())); + } + + ctx.constr(eq(round, 0)); + ctx.transition(eq(round, round.next())); + }); + + ctx.wg(move |ctx, inputs: PreInput| { + ctx.assign(round, inputs.round); + ctx.assign(t0, inputs.t0); + ctx.assign(t1, inputs.t1); + ctx.assign(f, inputs.f); + for (&q, &v) in wg_v_vec.iter().zip(inputs.v_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_h_vec.iter().zip(inputs.h_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_m_vec.iter().zip(inputs.m_vec.iter()) { + ctx.assign(q, v) + } + for (q_vec, v_vec) in wg_h_split_4bits_vec + .iter() + .zip(inputs.h_split_4bits_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_m_split_4bits_vec + .iter() + .zip(inputs.m_split_4bits_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_t_split_4bits_vec + .iter() + .zip(inputs.t_split_4bits_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_iv_split_4bits_vec + .iter() + .zip(inputs.iv_split_4bits_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_final_split_bits_vec + .iter() + .zip(inputs.final_split_bits_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + }) + }); + + let blake2f_g_setup_vec: Vec> = (0..MIXING_ROUNDS as usize) + .map(|r| { + ctx.step_type_def(format!("blake2f_g_setup_{}", r), |ctx| { + let v_vec = v_vec.clone(); + let wg_v_vec = v_vec.clone(); + let h_vec = h_vec.clone(); + let wg_h_vec = h_vec.clone(); + let m_vec = m_vec.clone(); + let wg_m_vec = m_vec.clone(); + + // v_mid1_vec is the new v_vec after the first round call to the g_setup function + let v_mid1_vec: Vec> = (0..V_LEN) + .map(|i| ctx.internal(format!("v_mid1_vec[{}]", i).as_str())) + .collect(); + let wg_v_mid1_vec = v_mid1_vec.clone(); + + // v_mid2_vec is the new v_vec after the second round call to the g_setup function + let v_mid2_vec: Vec> = (0..V_LEN) + .map(|i| ctx.internal(format!("v_mid2_vec[{}]", i).as_str())) + .collect(); + let wg_v_mid2_vec = v_mid2_vec.clone(); + + // v_mid3_vec is the new v_vec after the third round to the g_setup function + let v_mid3_vec: Vec> = (0..V_LEN) + .map(|i| ctx.internal(format!("v_mid3_vec[{}]", i).as_str())) + .collect(); + let wg_v_mid3_vec = v_mid3_vec.clone(); + + // v_mid4_vec is the new v_vec after the forth round to the g_setup function,as + // well as the final result of v_vec + let v_mid4_vec: Vec> = (0..V_LEN) + .map(|i| ctx.internal(format!("v_mid4_vec[{}]", i).as_str())) + .collect(); + let wg_v_mid4_vec = v_mid4_vec.clone(); + + // v_mid_va_bit_vec = 4bit split of v[a] + v[b] + x(or y) + let v_mid_va_bit_vec: Vec>> = (0..G_ROUNDS) + .map(|i| { + (0..SPLIT_64BITS + 1) + .map(|j| { + ctx.internal(format!("v_mid_va_bit_vec[{}][{}]", i, j).as_str()) + }) + .collect() + }) + .collect(); + let wg_v_mid_va_bit_vec = v_mid_va_bit_vec.clone(); + + // v_mid_vd_bit_vec = 4bit split of v[d] + let v_mid_vd_bit_vec: Vec>> = (0..G_ROUNDS) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| { + ctx.internal(format!("v_mid_vd_bit_vec[{}][{}]", i, j).as_str()) + }) + .collect() + }) + .collect(); + let wg_v_mid_vd_bit_vec = v_mid_vd_bit_vec.clone(); + + // v_mid_vc_bit_vec = 4bit split of v[c] + v[d] + let v_mid_vc_bit_vec: Vec>> = (0..G_ROUNDS) + .map(|i| { + (0..SPLIT_64BITS + 1) + .map(|j| { + ctx.internal(format!("v_mid_vc_bit_vec[{}][{}]", i, j).as_str()) + }) + .collect() + }) + .collect(); + let wg_v_mid_vc_bit_vec = v_mid_vc_bit_vec.clone(); + + // v_mid_vb_bit_vec = 4bit split of v[b] + let v_mid_vb_bit_vec: Vec>> = (0..G_ROUNDS) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| { + ctx.internal(format!("v_mid_vb_bit_vec[{}][{}]", i, j).as_str()) + }) + .collect() + }) + .collect(); + let wg_v_mid_vb_bit_vec = v_mid_vb_bit_vec.clone(); + + // v_xor_d_bit_vec = 4bit split of (v[d] ^ v[a]) >>> R1(or R3) + let v_xor_d_bit_vec: Vec>> = (0..G_ROUNDS) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| { + ctx.internal(format!("v_xor_d_bit_vec[{}][{}]", i, j).as_str()) + }) + .collect() + }) + .collect(); + let wg_v_xor_d_bit_vec = v_xor_d_bit_vec.clone(); + + // v_xor_b_bit_vec = 4bit split of (v[b] ^ v[c]) >>> R2(or R4) + let v_xor_b_bit_vec: Vec>> = (0..G_ROUNDS) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| { + ctx.internal(format!("v_xor_b_bit_vec[{}][{}]", i, j).as_str()) + }) + .collect() + }) + .collect(); + let wg_v_xor_b_bit_vec = v_xor_b_bit_vec.clone(); + + // b_bit_vec[i] * 8 + b_3bits_vec[i] = v_xor_b_bit_vec[i * 2 + 1][0] + // the step of v[b] := (v[b] ^ v[c]) >>> R4 needs to split a 4-bit value to a + // one-bit value and a 3-bit value + let b_bit_vec: Vec> = (0..G_ROUNDS / 2) + .map(|i| ctx.internal(format!("b_bit_vec[{}]", i).as_str())) + .collect(); + let wg_b_bit_vec = b_bit_vec.clone(); + let b_3bits_vec: Vec> = (0..G_ROUNDS / 2) + .map(|i| ctx.internal(format!("b_3bits_vec[{}]", i).as_str())) + .collect(); + let wg_b_3bits_vec = b_3bits_vec.clone(); + + ctx.setup(move |ctx| { + let s = SIGMA_VALUES[r % 10]; + + for i in 0..G_ROUNDS as usize { + let mut input_vec = v_vec.clone(); + let mut output_vec = v_mid1_vec.clone(); + if i >= 8 { + if i % 2 == 0 { + input_vec = v_mid2_vec.clone(); + output_vec = v_mid3_vec.clone(); + } else { + input_vec = v_mid3_vec.clone(); + output_vec = v_mid4_vec.clone(); + } + } else if i % 2 == 1 { + input_vec = v_mid1_vec.clone(); + output_vec = v_mid2_vec.clone(); + } + let (mut a, mut b, mut c, mut d) = + (i / 2, 4 + i / 2, 8 + i / 2, 12 + i / 2); + if i / 2 == 4 { + (a, b, c, d) = (0, 5, 10, 15); + } else if i / 2 == 5 { + (a, b, c, d) = (1, 6, 11, 12); + } else if i / 2 == 6 { + (a, b, c, d) = (2, 7, 8, 13); + } else if i / 2 == 7 { + (a, b, c, d) = (3, 4, 9, 14); + } + let mut move1 = R1 / 4; + let mut move2 = R2 / 4; + if i % 2 == 1 { + move1 = R3 / 4; + move2 = (R4 + 1) / 4; + } + let q_params = GStepParams { + input_vec, + output_vec, + m_vec: m_vec.clone(), + v_mid_va_bit_vec: v_mid_va_bit_vec[i].clone(), + v_mid_vb_bit_vec: v_mid_vb_bit_vec[i].clone(), + v_mid_vc_bit_vec: v_mid_vc_bit_vec[i].clone(), + v_mid_vd_bit_vec: v_mid_vd_bit_vec[i].clone(), + v_xor_b_bit_vec: v_xor_b_bit_vec[i].clone(), + v_xor_d_bit_vec: v_xor_d_bit_vec[i].clone(), + b_bit: b_bit_vec[i / 2], + b_3bits: b_3bits_vec[i / 2], + }; + g_setup( + ctx, + params, + q_params, + (a, b, c, d), + (move1, move2), + s[i], + i % 2 == 1, + ); + } + + // check v_vec v_vec.next() + for (&v, &new_v) in v_vec.iter().zip(v_mid4_vec.iter()) { + ctx.transition(eq(new_v, v.next())); + } + // check h_vec h_vec.next() + for &h in h_vec.iter() { + ctx.transition(eq(h, h.next())); + } + // check m_vec m_vec.next() + if r < MIXING_ROUNDS as usize - 1 { + for &m in m_vec.iter() { + ctx.transition(eq(m, m.next())); + } + } + ctx.transition(eq(round + 1, round.next())); + }); + + ctx.wg(move |ctx, inputs: GInput| { + ctx.assign(round, inputs.round); + for (&q, &v) in wg_v_vec.iter().zip(inputs.v_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_h_vec.iter().zip(inputs.h_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_m_vec.iter().zip(inputs.m_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_v_mid1_vec.iter().zip(inputs.v_mid1_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_v_mid2_vec.iter().zip(inputs.v_mid2_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_v_mid3_vec.iter().zip(inputs.v_mid3_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_v_mid4_vec.iter().zip(inputs.v_mid4_vec.iter()) { + ctx.assign(q, v) + } + for (q_vec, v_vec) in wg_v_mid_va_bit_vec + .iter() + .zip(inputs.v_mid_va_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_v_mid_vb_bit_vec + .iter() + .zip(inputs.v_mid_vb_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_v_mid_vc_bit_vec + .iter() + .zip(inputs.v_mid_vc_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_v_mid_vd_bit_vec + .iter() + .zip(inputs.v_mid_vd_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in + wg_v_xor_d_bit_vec.iter().zip(inputs.v_xor_d_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in + wg_v_xor_b_bit_vec.iter().zip(inputs.v_xor_b_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (&q, &v) in wg_b_bit_vec.iter().zip(inputs.b_bit_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_b_3bits_vec.iter().zip(inputs.b_3bits_vec.iter()) { + ctx.assign(q, v) + } + }) + }) + }) + .collect(); + + let blake2f_final_step = ctx.step_type_def("blake2f_final_step", |ctx| { + let v_vec = v_vec.clone(); + let wg_v_vec = v_vec.clone(); + + let h_vec = h_vec.clone(); + let wg_h_vec = h_vec.clone(); + + let output_vec = m_vec.clone(); + let wg_output_vec = output_vec.clone(); + + // v_split_bit_vec = 4bit split of v_vec + let v_split_bit_vec: Vec>> = (0..V_LEN) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("v_split_bit_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_v_split_bit_vec = v_split_bit_vec.clone(); + + // h_split_bit_vec = 4bit split of h_vec + let h_split_bit_vec: Vec>> = (0..H_LEN) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("h_split_bit_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_h_split_bit_vec = h_split_bit_vec.clone(); + + // v_xor_split_bit_vec = 4bit split of v[i] ^ v[i + 8] + let v_xor_split_bit_vec: Vec>> = (0..8) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("v_xor_split_bit_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_v_xor_split_bit_vec = v_xor_split_bit_vec.clone(); + + // final_split_bit_vec = 4bit split of h[i] ^ v[i] ^ v[i + 8] + let final_split_bit_vec: Vec>> = (0..8) + .map(|i| { + (0..SPLIT_64BITS) + .map(|j| ctx.internal(format!("v_xor_split_bit_vec[{}][{}]", i, j).as_str())) + .collect() + }) + .collect(); + let wg_final_split_bit_vec = final_split_bit_vec.clone(); + + ctx.setup(move |ctx| { + // check split-fields of v_vec + for (&v, v_split) in v_vec.iter().zip(v_split_bit_vec.iter()) { + let mut v_4bits_sum_value = 0.expr() * 1; + for &bits in v_split.iter().rev() { + v_4bits_sum_value = v_4bits_sum_value * BASE_4BITS + bits; + params.check_4bit(ctx, bits); + } + ctx.constr(eq(v_4bits_sum_value, v)); + } + + // check split-fields of h_vec + for (&h, h_split) in h_vec.iter().zip(h_split_bit_vec.iter()) { + let mut h_4bits_sum_value = 0.expr() * 1; + for &bits in h_split.iter().rev() { + h_4bits_sum_value = h_4bits_sum_value * BASE_4BITS + bits; + params.check_4bit(ctx, bits); + } + ctx.constr(eq(h_4bits_sum_value, h)); + } + + // check split-fields of v[i] ^ v[i+8] + for (xor_vec, (v1_vec, v2_vec)) in v_xor_split_bit_vec.iter().zip( + v_split_bit_vec[0..V_LEN / 2] + .iter() + .zip(v_split_bit_vec[V_LEN / 2..V_LEN].iter()), + ) { + for (&xor, (&v1, &v2)) in xor_vec.iter().zip(v1_vec.iter().zip(v2_vec.iter())) { + params.check_xor(ctx, v1, v2, xor); + } + } + + // check split-fields of h[i] ^ v[i] ^ v[i+8] + for (final_vec, (xor_vec, h_vec)) in final_split_bit_vec + .iter() + .zip(v_xor_split_bit_vec.iter().zip(h_split_bit_vec.iter())) + { + for (&value, (&v1, &v2)) in final_vec.iter().zip(xor_vec.iter().zip(h_vec.iter())) { + params.check_xor(ctx, v1, v2, value); + } + } + + // check output = h[i] ^ v[i] ^ v[i+8] + for (final_vec, &output) in final_split_bit_vec.iter().zip(output_vec.iter()) { + let mut final_4bits_sum_value = 0.expr() * 1; + for &value in final_vec.iter().rev() { + final_4bits_sum_value = final_4bits_sum_value * BASE_4BITS + value; + } + ctx.constr(eq(output, final_4bits_sum_value)); + } + ctx.constr(eq(round, MIXING_ROUNDS)); + }); + + ctx.wg(move |ctx, inputs: FinalInput| { + ctx.assign(round, inputs.round); + for (&q, &v) in wg_v_vec.iter().zip(inputs.v_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_h_vec.iter().zip(inputs.h_vec.iter()) { + ctx.assign(q, v) + } + for (&q, &v) in wg_output_vec.iter().zip(inputs.output_vec.iter()) { + ctx.assign(q, v) + } + for (q_vec, v_vec) in wg_v_split_bit_vec.iter().zip(inputs.v_split_bit_vec.iter()) { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_h_split_bit_vec.iter().zip(inputs.h_split_bit_vec.iter()) { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_v_xor_split_bit_vec + .iter() + .zip(inputs.v_xor_split_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + for (q_vec, v_vec) in wg_final_split_bit_vec + .iter() + .zip(inputs.final_split_bit_vec.iter()) + { + for (&q, &v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(q, v) + } + } + }) + }); + + ctx.pragma_first_step(&blake2f_pre_step); + ctx.pragma_last_step(&blake2f_final_step); + ctx.pragma_num_steps(MIXING_ROUNDS as usize + 2); + + ctx.trace(move |ctx, values| { + let h_vec_values = values.h_vec.to_vec(); + let h_split_4bits_vec = split_to_4bits_values::(&h_vec_values); + + let m_vec_values = values.m_vec.to_vec(); + let m_split_4bits_vec = split_to_4bits_values::(&m_vec_values); + + let mut iv_vec_values = IV_VALUES.to_vec(); + let iv_split_4bits_vec: Vec> = split_to_4bits_values::(&iv_vec_values[4..7]); + + let mut v_vec_values = h_vec_values.clone(); + v_vec_values.append(&mut iv_vec_values); + + let t_split_4bits_vec = split_to_4bits_values::(&[values.t0, values.t1]); + + let final_values = vec![ + v_vec_values[12] ^ values.t0, + v_vec_values[13] ^ values.t1, + v_vec_values[14] ^ 0xFFFFFFFFFFFFFFFF, + ]; + let final_split_bits_vec = split_to_4bits_values::(&final_values); + + let pre_inputs = PreInput { + round: F::ZERO, + t0: F::from(values.t0), + t1: F::from(values.t1), + f: F::from(if values.f { 1 } else { 0 }), + h_vec: h_vec_values.iter().map(|&v| F::from(v)).collect(), + m_vec: m_vec_values.iter().map(|&v| F::from(v)).collect(), + v_vec: v_vec_values.iter().map(|&v| F::from(v)).collect(), + h_split_4bits_vec, + m_split_4bits_vec, + t_split_4bits_vec, + iv_split_4bits_vec, + final_split_bits_vec, + }; + ctx.add(&blake2f_pre_step, pre_inputs); + + v_vec_values[12] = final_values[0]; + v_vec_values[13] = final_values[1]; + if values.f { + v_vec_values[14] = final_values[2]; + } + + for r in 0..values.round { + let s = SIGMA_VALUES[(r as usize) % 10]; + + let mut v_mid1_vec_values = v_vec_values.clone(); + let mut v_mid2_vec_values = v_vec_values.clone(); + let mut v_mid_va_bit_vec = Vec::new(); + let mut v_mid_vb_bit_vec = Vec::new(); + let mut v_mid_vc_bit_vec = Vec::new(); + let mut v_mid_vd_bit_vec = Vec::new(); + let mut v_xor_d_bit_vec = Vec::new(); + let mut v_xor_b_bit_vec = Vec::new(); + let mut b_bit_vec = Vec::new(); + let mut b_3bits_vec = Vec::new(); + + g_wg( + (&mut v_mid1_vec_values, &mut v_mid2_vec_values), + (0, 4, 8, 12), + (m_vec_values[s[0]], m_vec_values[s[1]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + g_wg( + (&mut v_mid1_vec_values, &mut v_mid2_vec_values), + (1, 5, 9, 13), + (m_vec_values[s[2]], m_vec_values[s[3]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + g_wg( + (&mut v_mid1_vec_values, &mut v_mid2_vec_values), + (2, 6, 10, 14), + (m_vec_values[s[4]], m_vec_values[s[5]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + g_wg( + (&mut v_mid1_vec_values, &mut v_mid2_vec_values), + (3, 7, 11, 15), + (m_vec_values[s[6]], m_vec_values[s[7]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + + let mut v_mid3_vec_values = v_mid2_vec_values.clone(); + let mut v_mid4_vec_values = v_mid2_vec_values.clone(); + g_wg( + (&mut v_mid3_vec_values, &mut v_mid4_vec_values), + (0, 5, 10, 15), + (m_vec_values[s[8]], m_vec_values[s[9]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + g_wg( + (&mut v_mid3_vec_values, &mut v_mid4_vec_values), + (1, 6, 11, 12), + (m_vec_values[s[10]], m_vec_values[s[11]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + g_wg( + (&mut v_mid3_vec_values, &mut v_mid4_vec_values), + (2, 7, 8, 13), + (m_vec_values[s[12]], m_vec_values[s[13]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + g_wg( + (&mut v_mid3_vec_values, &mut v_mid4_vec_values), + (3, 4, 9, 14), + (m_vec_values[s[14]], m_vec_values[s[15]]), + (&mut v_mid_va_bit_vec, &mut v_mid_vb_bit_vec), + (&mut v_mid_vc_bit_vec, &mut v_mid_vd_bit_vec), + (&mut v_xor_d_bit_vec, &mut v_xor_b_bit_vec), + (&mut b_bit_vec, &mut b_3bits_vec), + ); + + let ginputs = GInput { + round: F::from(r as u64), + v_vec: v_vec_values.iter().map(|&v| F::from(v)).collect(), + h_vec: h_vec_values.iter().map(|&v| F::from(v)).collect(), + m_vec: m_vec_values.iter().map(|&v| F::from(v)).collect(), + v_mid1_vec: v_mid1_vec_values.iter().map(|&v| F::from(v)).collect(), + v_mid2_vec: v_mid2_vec_values.iter().map(|&v| F::from(v)).collect(), + v_mid3_vec: v_mid3_vec_values.iter().map(|&v| F::from(v)).collect(), + v_mid4_vec: v_mid4_vec_values.iter().map(|&v| F::from(v)).collect(), + v_mid_va_bit_vec, + v_mid_vb_bit_vec, + v_mid_vc_bit_vec, + v_mid_vd_bit_vec, + v_xor_d_bit_vec, + v_xor_b_bit_vec, + b_bit_vec, + b_3bits_vec, + }; + ctx.add(&blake2f_g_setup_vec[r as usize], ginputs); + v_vec_values = v_mid4_vec_values.clone(); + } + + let output_vec_values: Vec = h_vec_values + .iter() + .zip( + v_vec_values[0..8] + .iter() + .zip(v_vec_values[V_LEN / 2..V_LEN].iter()), + ) + .map(|(h, (v1, v2))| h ^ v1 ^ v2) + .collect(); + + let final_inputs = FinalInput { + round: F::from(values.round as u64), + v_vec: v_vec_values.iter().map(|&v| F::from(v)).collect(), + h_vec: h_vec_values.iter().map(|&v| F::from(v)).collect(), + output_vec: output_vec_values.iter().map(|&v| F::from(v)).collect(), + v_split_bit_vec: v_vec_values + .iter() + .map(|&v| split_value_4bits(v as u128, SPLIT_64BITS)) + .collect(), + h_split_bit_vec: h_vec_values + .iter() + .map(|&v| split_value_4bits(v as u128, SPLIT_64BITS)) + .collect(), + v_xor_split_bit_vec: v_vec_values[0..V_LEN / 2] + .iter() + .zip(v_vec_values[V_LEN / 2..V_LEN].iter()) + .map(|(&v1, &v2)| split_xor_value(v1, v2)) + .collect(), + final_split_bit_vec: output_vec_values + .iter() + .map(|&output| split_value_4bits(output as u128, SPLIT_64BITS)) + .collect(), + }; + ctx.add(&blake2f_final_step, final_inputs); + // ba80a53f981c4d0d, 6a2797b69f12f6e9, 4c212f14685ac4b7, 4b12bb6fdbffa2d1 + // 7d87c5392aab792d, c252d5de4533cc95, 18d38aa8dbf1925a,b92386edd4009923 + println!( + "output = {:?} \n {:?}", + u64_to_string(&output_vec_values[0..4].try_into().unwrap()), + u64_to_string(&output_vec_values[4..8].try_into().unwrap()) + ); + }) +} + +fn blake2f_super_circuit() -> SuperCircuit { + super_circuit::("blake2f", |ctx| { + let single_config = config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}); + let (_, iv_table) = ctx.sub_circuit(single_config.clone(), blake2f_iv_table, IV_LEN); + let (_, bits_table) = ctx.sub_circuit( + single_config.clone(), + blake2f_4bits_table, + SPLIT_64BITS as usize, + ); + let (_, xor_4bits_table) = ctx.sub_circuit( + single_config, + blake2f_xor_4bits_table, + (SPLIT_64BITS * SPLIT_64BITS) as usize, + ); + + let maxwidth_config = config( + MaxWidthCellManager::new(250, true), + SimpleStepSelectorBuilder {}, + ); + + let params = CircuitParams { + iv_table, + bits_table, + xor_4bits_table, + }; + let (blake2f, _) = ctx.sub_circuit(maxwidth_config, blake2f_circuit, params); + + ctx.mapping(move |ctx, values| { + ctx.map(&blake2f, values); + }) + }) +} + +fn main() { + let super_circuit = blake2f_super_circuit::(); + let compiled = chiquitoSuperCircuit2Halo2(&super_circuit); + + // h[0] = hex"48c9bdf267e6096a 3ba7ca8485ae67bb 2bf894fe72f36e3c f1361d5f3af54fa5"; + // h[1] = hex"d182e6ad7f520e51 1f6c3e2b8c68059b 6bbd41fbabd9831f 79217e1319cde05b"; + let h0 = string_to_u64([ + "48c9bdf267e6096a", + "3ba7ca8485ae67bb", + "2bf894fe72f36e3c", + "f1361d5f3af54fa5", + ]); + let h1 = string_to_u64([ + "d182e6ad7f520e51", + "1f6c3e2b8c68059b", + "6bbd41fbabd9831f", + "79217e1319cde05b", + ]); + // m[0] = hex"6162630000000000 0000000000000000 0000000000000000 0000000000000000"; + // m[1] = hex"0000000000000000 0000000000000000 0000000000000000 0000000000000000"; + // m[2] = hex"0000000000000000 0000000000000000 0000000000000000 0000000000000000"; + // m[3] = hex"0000000000000000 0000000000000000 0000000000000000 0000000000000000"; + let m0 = string_to_u64([ + "6162630000000000", + "0000000000000000", + "0000000000000000", + "0000000000000000", + ]); + let m1 = string_to_u64([ + "0000000000000000", + "0000000000000000", + "0000000000000000", + "0000000000000000", + ]); + let m2 = string_to_u64([ + "0000000000000000", + "0000000000000000", + "0000000000000000", + "0000000000000000", + ]); + let m3 = string_to_u64([ + "0000000000000000", + "0000000000000000", + "0000000000000000", + "0000000000000000", + ]); + + let values = InputValues { + round: 12, + + h_vec: [ + h0[0], // 0x6a09e667f2bdc948, + h0[1], // 0xbb67ae8584caa73b, + h0[2], // 0x3c6ef372fe94f82b, + h0[3], // 0xa54ff53a5f1d36f1, + h1[0], // 0x510e527fade682d1, + h1[1], // 0x9b05688c2b3e6c1f, + h1[2], // 0x1f83d9abfb41bd6b, + h1[3], // 0x5be0cd19137e2179, + ], // 8 * 64bits + + m_vec: [ + m0[0], // 0x636261, + m0[1], // 0, + m0[2], // 0, + m0[3], // 0, + m1[0], // 0, + m1[1], // 0, + m1[2], // 0, + m1[3], // 0, + m2[0], // 0, + m2[1], // 0, + m2[2], // 0, + m2[3], // 0, + m3[0], // 0, + m3[1], // 0, + m3[2], // 0, + m3[3], // 0, + ], // 16 * 64bits + t0: 3, // 64bits + t1: 0, // 64bits + f: true, // 8bits + }; + + let circuit = + ChiquitoHalo2SuperCircuit::new(compiled, super_circuit.get_mapping().generate(values)); + + let prover = MockProver::run(9, &circuit, Vec::new()).unwrap(); + let result = prover.verify(); + + println!("result = {:#?}", result); + + if let Err(failures) = &result { + for failure in failures.iter() { + println!("{}", failure); + } + } +} diff --git a/examples/factorial.rs b/examples/factorial.rs index c927ab49..e47702bf 100644 --- a/examples/factorial.rs +++ b/examples/factorial.rs @@ -136,7 +136,7 @@ fn main() { let prover = MockProver::::run(10, &circuit, circuit.instance()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("result = {:#?}", result); @@ -167,7 +167,7 @@ fn main() { // same as halo2 boilerplate above let prover_plaf = MockProver::::run(8, &plaf_circuit, Vec::new()).unwrap(); - let result_plaf = prover_plaf.verify_par(); + let result_plaf = prover_plaf.verify(); println!("result = {:#?}", result_plaf); diff --git a/examples/fibo_with_padding.rs b/examples/fibo_with_padding.rs index 403f2ae2..6a4fc481 100644 --- a/examples/fibo_with_padding.rs +++ b/examples/fibo_with_padding.rs @@ -206,7 +206,7 @@ fn main() { let prover = MockProver::::run(7, &circuit, circuit.instance()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("{:#?}", result); @@ -237,7 +237,7 @@ fn main() { // same as halo2 boilerplate above let prover_plaf = MockProver::::run(8, &plaf_circuit, plaf_circuit.instance()).unwrap(); - let result_plaf = prover_plaf.verify_par(); + let result_plaf = prover_plaf.verify(); println!("result = {:#?}", result_plaf); diff --git a/examples/fibonacci.py b/examples/fibonacci.py index 4ffa7b31..0c7aedfa 100644 --- a/examples/fibonacci.py +++ b/examples/fibonacci.py @@ -84,3 +84,5 @@ def trace(self, n): ) # 2^k specifies the number of PLONKish table rows in Halo2 another_fibo_witness = fibo.gen_witness(4) fibo.halo2_mock_prover(another_fibo_witness, k=7) + +fibo.to_pil(fibo_witness, "FiboCircuit") diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index 4021cdda..b01d00e3 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -3,30 +3,39 @@ use std::hash::Hash; use chiquito::{ field::Field, frontend::dsl::circuit, // main function for constructing an AST circuit - plonkish::backend::halo2::{chiquito2Halo2, ChiquitoHalo2Circuit}, /* compiles to + plonkish::{ + backend::{ + halo2::{chiquito2Halo2, ChiquitoHalo2Circuit}, + hyperplonk::ChiquitoHyperPlonkCircuit, + }, + compiler::{ + cell_manager::SingleRowCellManager, // input for constructing the compiler + compile, // input for constructing the compiler + config, + step_selector::SimpleStepSelectorBuilder, + }, + ir::{assignments::AssignmentGenerator, Circuit}, + }, /* compiles to * Chiquito Halo2 * backend, * which can be * integrated into * Halo2 * circuit */ - plonkish::compiler::{ - cell_manager::SingleRowCellManager, // input for constructing the compiler - compile, // input for constructing the compiler - config, - step_selector::SimpleStepSelectorBuilder, - }, - plonkish::ir::{assignments::AssignmentGenerator, Circuit}, // compiled circuit type poly::ToField, + sbpir::SBPIR, }; -use halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; +use halo2_proofs::dev::MockProver; // the main circuit function: returns the compiled IR of a Chiquito circuit // Generic types F, (), (u64, 64) stand for: // 1. type that implements a field trait // 2. empty trace arguments, i.e. (), because there are no external inputs to the Chiquito circuit // 3. two witness generation arguments both of u64 type, i.e. (u64, u64) -fn fibo_circuit + Hash>() -> (Circuit, Option>) { + +type FiboReturn = (Circuit, Option>, SBPIR); + +fn fibo_circuit + Hash>() -> FiboReturn { // PLONKish table for the Fibonacci circuit: // | a | b | c | // | 1 | 1 | 2 | @@ -73,7 +82,7 @@ fn fibo_circuit + Hash>() -> (Circuit, Option + Hash>() -> (Circuit, Option + Hash>() -> (Circuit, Option + Hash>() -> (Circuit, Option + Hash>() -> (Circuit, Option(); + let (chiquito, wit_gen, _) = fibo_circuit::(); let compiled = chiquito2Halo2(chiquito); let circuit = ChiquitoHalo2Circuit::new(compiled, wit_gen.map(|g| g.generate(()))); let prover = MockProver::::run(7, &circuit, circuit.instance()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("{:#?}", result); @@ -137,7 +148,7 @@ fn main() { use polyexen::plaf::{backends::halo2::PlafH2Circuit, WitnessDisplayCSV}; // get Chiquito ir - let (circuit, wit_gen) = fibo_circuit::(); + let (circuit, wit_gen, _) = fibo_circuit::(); // get Plaf let (plaf, plaf_wit_gen) = chiquito2Plaf(circuit, 8, false); let wit = plaf_wit_gen.generate(wit_gen.map(|v| v.generate(()))); @@ -153,7 +164,7 @@ fn main() { // same as halo2 boilerplate above let prover_plaf = MockProver::::run(8, &plaf_circuit, plaf_circuit.instance()).unwrap(); - let result_plaf = prover_plaf.verify_par(); + let result_plaf = prover_plaf.verify(); println!("result = {:#?}", result_plaf); @@ -162,4 +173,34 @@ fn main() { println!("{}", failure); } } + + // hyperplonk boilerplate + use hyperplonk_benchmark::proof_system::{bench_plonkish_backend, System}; + use plonkish_backend::{ + backend, + halo2_curves::bn256::{Bn256, Fr}, + pcs::{multilinear, univariate}, + }; + // get Chiquito ir + let (circuit, assignment_generator, _) = fibo_circuit::(); + // get assignments + let assignments = assignment_generator.unwrap().generate(()); + // get hyperplonk circuit + let mut hyperplonk_circuit = ChiquitoHyperPlonkCircuit::new(4, circuit); + hyperplonk_circuit.set_assignment(assignments); + + type GeminiKzg = multilinear::Gemini>; + type HyperPlonk = backend::hyperplonk::HyperPlonk; + bench_plonkish_backend::(System::HyperPlonk, 4, &hyperplonk_circuit); + + // pil boilerplate + use chiquito::pil::backend::powdr_pil::chiquito2Pil; + + let (_, wit_gen, circuit) = fibo_circuit::(); + let pil = chiquito2Pil( + circuit, + Some(wit_gen.unwrap().generate_trace_witness(())), + String::from("FiboCircuit"), + ); + print!("{}", pil); } diff --git a/examples/keccak.rs b/examples/keccak.rs new file mode 100644 index 00000000..bc06c93b --- /dev/null +++ b/examples/keccak.rs @@ -0,0 +1,2365 @@ +use chiquito::{ + frontend::dsl::{lb::LookupTable, super_circuit, CircuitContext, StepTypeWGHandler}, + plonkish::{ + backend::halo2::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, + compiler::{ + cell_manager::{MaxWidthCellManager, SingleRowCellManager}, + config, + step_selector::SimpleStepSelectorBuilder, + }, + ir::sc::SuperCircuit, + }, + poly::ToExpr, + sbpir::query::Queriable, +}; +use std::{hash::Hash, ops::Neg}; + +use halo2_proofs::{ + dev::MockProver, + halo2curves::{bn256::Fr, group::ff::PrimeField}, +}; + +use std::{ + fs::File, + io::{self, Write}, +}; + +const BIT_COUNT: u64 = 3; +const PART_SIZE: u64 = 5; +const NUM_BYTES_PER_WORD: u64 = 8; +const NUM_BITS_PER_BYTE: u64 = 8; +const NUM_WORDS_TO_ABSORB: u64 = 17; +const RATE: u64 = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; +const NUM_BITS_PER_WORD: u64 = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE; +const NUM_PER_WORD: u64 = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE / 2; +const RATE_IN_BITS: u64 = RATE * NUM_BITS_PER_BYTE; +const NUM_ROUNDS: u64 = 24; +const BIT_SIZE: usize = 2usize.pow(BIT_COUNT as u32); + +const NUM_PER_WORD_BATCH3: u64 = 22; +const NUM_PER_WORD_BATCH4: u64 = 16; + +const SQUEEZE_VECTOR_NUM: u64 = 4; +const SQUEEZE_SPLIT_NUM: u64 = 16; + +const PART_SIZE_SQURE: u64 = PART_SIZE * PART_SIZE; + +pub const ROUND_CST: [u64; NUM_ROUNDS as usize + 1] = [ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808a, + 0x8000000080008000, + 0x000000000000808b, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008a, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000a, + 0x000000008000808b, + 0x800000000000008b, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800a, + 0x800000008000000a, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, + 0x0000000000000000, +]; + +pub const XOR_VALUE_BATCH2: [u64; 36] = [ + 0x0, 0x1, 0x0, 0x1, 0x0, 0x1, 0x8, 0x9, 0x8, 0x9, 0x8, 0x9, 0x0, 0x1, 0x0, 0x1, 0x0, 0x1, 0x8, + 0x9, 0x8, 0x9, 0x8, 0x9, 0x0, 0x1, 0x0, 0x1, 0x0, 0x1, 0x8, 0x9, 0x8, 0x9, 0x8, 0x9, +]; + +pub const XOR_VALUE_BATCH3: [u64; 64] = [ + 0x0, 0x1, 0x0, 0x1, 0x8, 0x9, 0x8, 0x9, 0x0, 0x1, 0x0, 0x1, 0x8, 0x9, 0x8, 0x9, 0x40, 0x41, + 0x40, 0x41, 0x48, 0x49, 0x48, 0x49, 0x40, 0x41, 0x40, 0x41, 0x48, 0x49, 0x48, 0x49, 0x0, 0x1, + 0x0, 0x1, 0x8, 0x9, 0x8, 0x9, 0x0, 0x1, 0x0, 0x1, 0x8, 0x9, 0x8, 0x9, 0x40, 0x41, 0x40, 0x41, + 0x48, 0x49, 0x48, 0x49, 0x40, 0x41, 0x40, 0x41, 0x48, 0x49, 0x48, 0x49, +]; + +pub const XOR_VALUE_BATCH4: [u64; 81] = [ + 0x0, 0x1, 0x0, 0x8, 0x9, 0x8, 0x0, 0x1, 0x0, 0x40, 0x41, 0x40, 0x48, 0x49, 0x48, 0x40, 0x41, + 0x40, 0x0, 0x1, 0x0, 0x8, 0x9, 0x8, 0x0, 0x1, 0x0, 0x200, 0x201, 0x200, 0x208, 0x209, 0x208, + 0x200, 0x201, 0x200, 0x240, 0x241, 0x240, 0x248, 0x249, 0x248, 0x240, 0x241, 0x240, 0x200, + 0x201, 0x200, 0x208, 0x209, 0x208, 0x200, 0x201, 0x200, 0x0, 0x1, 0x0, 0x8, 0x9, 0x8, 0x0, 0x1, + 0x0, 0x40, 0x41, 0x40, 0x48, 0x49, 0x48, 0x40, 0x41, 0x40, 0x0, 0x1, 0x0, 0x8, 0x9, 0x8, 0x0, + 0x1, 0x0, +]; + +pub const CHI_VALUE: [u64; 125] = [ + 0x0, 0x1, 0x1, 0x0, 0x0, 0x8, 0x9, 0x9, 0x8, 0x8, 0x8, 0x9, 0x9, 0x8, 0x8, 0x0, 0x1, 0x1, 0x0, + 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x40, 0x41, 0x41, 0x40, 0x40, 0x48, 0x49, 0x49, 0x48, 0x48, 0x48, + 0x49, 0x49, 0x48, 0x48, 0x40, 0x41, 0x41, 0x40, 0x40, 0x40, 0x41, 0x41, 0x40, 0x40, 0x40, 0x41, + 0x41, 0x40, 0x40, 0x48, 0x49, 0x49, 0x48, 0x48, 0x48, 0x49, 0x49, 0x48, 0x48, 0x40, 0x41, 0x41, + 0x40, 0x40, 0x40, 0x41, 0x41, 0x40, 0x40, 0x0, 0x1, 0x1, 0x0, 0x0, 0x8, 0x9, 0x9, 0x8, 0x8, + 0x8, 0x9, 0x9, 0x8, 0x8, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, + 0x0, 0x8, 0x9, 0x9, 0x8, 0x8, 0x8, 0x9, 0x9, 0x8, 0x8, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x1, 0x1, + 0x0, 0x0, +]; + +/// Pack bits in the range [0,BIT_SIZE[ into a sparse keccak word +fn pack(bits: &[u8]) -> F { + pack_with_base(bits, BIT_SIZE) +} + +/// Pack bits in the range [0,BIT_SIZE[ into a sparse keccak word with the +/// specified bit base +fn pack_with_base(bits: &[u8], base: usize) -> F { + // \sum 8^i * bit_i + let base = F::from(base as u64); + bits.iter() + .rev() + .fold(F::ZERO, |acc, &bit| acc * base + F::from(bit as u64)) +} + +fn pack_u64(value: u64) -> F { + pack( + &((0..NUM_BITS_PER_WORD) + .map(|i| ((value >> i) & 1) as u8) + .collect::>()), + ) +} + +/// Calculates a ^ b with a and b field elements +fn field_xor>(a: F, b: F) -> F { + let mut bytes = [0u8; 32]; + for (idx, (a, b)) in a + .to_repr() + .as_ref() + .iter() + .zip(b.to_repr().as_ref().iter()) + .enumerate() + { + bytes[idx] = *a ^ *b; + } + F::from_repr(bytes).unwrap() +} + +fn convert_bytes_to_bits(bytes: Vec) -> Vec { + bytes + .iter() + .map(|&byte| { + let mut byte = byte; + (0..8) + .map(|_| { + let b = byte % 2; + byte /= 2; + b + }) + .collect() + }) + .collect::>>() + .concat() +} + +fn convert_field_to_vec_bits(value: F) -> Vec { + let mut v_vec = Vec::new(); + let mut left = 0; + for (idx, &v1) in value.to_repr().as_ref().iter().enumerate() { + if idx % 3 == 0 { + v_vec.push(v1 % 8); + v_vec.push((v1 / 8) % 8); + left = v1 / 64; + } else if idx % 3 == 1 { + v_vec.push((v1 % 2) * 4 + left); + v_vec.push((v1 / 2) % 8); + v_vec.push((v1 / 16) % 8); + left = v1 / 128; + } else { + v_vec.push((v1 % 4) * 2 + left); + v_vec.push((v1 / 4) % 8); + v_vec.push(v1 / 32); + left = 0; + } + } + v_vec[0..64].to_vec() +} + +fn convert_bits_to_f>(value_vec: &[u8]) -> F { + assert_eq!(value_vec.len(), NUM_BITS_PER_WORD as usize); + let mut sum_value_arr: Vec = (0..24) + .map(|t| { + if t % 3 == 0 { + value_vec[(t / 3) * 8] + + value_vec[(t / 3) * 8 + 1] * 8 + + (value_vec[(t / 3) * 8 + 2] % 4) * 64 + } else if t % 3 == 1 { + value_vec[(t / 3) * 8 + 2] / 4 + + value_vec[(t / 3) * 8 + 3] * 2 + + (value_vec[(t / 3) * 8 + 4]) * 16 + + ((value_vec[(t / 3) * 8 + 5]) % 2) * 128 + } else { + value_vec[(t / 3) * 8 + 5] / 2 + + value_vec[(t / 3) * 8 + 6] * 4 + + (value_vec[(t / 3) * 8 + 7]) * 32 + } + }) + .collect(); + while sum_value_arr.len() < 32 { + sum_value_arr.push(0); + } + F::from_repr(sum_value_arr.try_into().unwrap()).unwrap() +} + +fn eval_keccak_f_to_bit_vec4>(value1: F, value2: F) -> Vec<(F, F)> { + let v1_vec = convert_field_to_vec_bits(value1); + let v2_vec = convert_field_to_vec_bits(value2); + assert_eq!(v1_vec.len(), NUM_BITS_PER_WORD as usize); + assert_eq!(v2_vec.len(), NUM_BITS_PER_WORD as usize); + (0..NUM_PER_WORD_BATCH4 as usize) + .map(|i| { + ( + F::from_u128( + v1_vec[4 * i] as u128 + + v1_vec[4 * i + 1] as u128 * 8 + + v1_vec[4 * i + 2] as u128 * 64 + + v1_vec[4 * i + 3] as u128 * 512, + ), + F::from_u128( + v2_vec[4 * i] as u128 + + v2_vec[4 * i + 1] as u128 * 8 + + v2_vec[4 * i + 2] as u128 * 64 + + v2_vec[4 * i + 3] as u128 * 512, + ), + ) + }) + .collect() +} + +fn keccak_xor_table_batch2( + ctx: &mut CircuitContext, + lens: usize, +) -> LookupTable { + use chiquito::frontend::dsl::cb::*; + + let lookup_xor_row: Queriable = ctx.fixed("xor row(batch 2)"); + let lookup_xor_c: Queriable = ctx.fixed("xor value(batch 2)"); + + let constants_value = XOR_VALUE_BATCH2; + assert_eq!(lens, constants_value.len()); + ctx.pragma_num_steps(lens); + + ctx.fixed_gen(move |ctx| { + for (i, &value) in constants_value.iter().enumerate().take(lens) { + ctx.assign(i, lookup_xor_row, F::from(((i / 6) * 8 + i % 6) as u64)); + ctx.assign(i, lookup_xor_c, F::from(value)); + } + }); + + ctx.new_table(table().add(lookup_xor_row).add(lookup_xor_c)) +} + +fn keccak_xor_table_batch3( + ctx: &mut CircuitContext, + lens: usize, +) -> LookupTable { + use chiquito::frontend::dsl::cb::*; + + let lookup_xor_row: Queriable = ctx.fixed("xor row(batch 3)"); + let lookup_xor_c: Queriable = ctx.fixed("xor value(batch 3)"); + + let constants_value = XOR_VALUE_BATCH3; + assert_eq!(lens, constants_value.len()); + ctx.pragma_num_steps(lens); + ctx.fixed_gen(move |ctx| { + for (i, &value) in constants_value.iter().enumerate().take(lens) { + ctx.assign( + i, + lookup_xor_row, + F::from(((i / 16) * 64 + (i % 16) / 4 * 8 + i % 4) as u64), + ); + ctx.assign(i, lookup_xor_c, F::from(value)); + } + }); + + ctx.new_table(table().add(lookup_xor_row).add(lookup_xor_c)) +} + +fn keccak_xor_table_batch4( + ctx: &mut CircuitContext, + lens: usize, +) -> LookupTable { + use chiquito::frontend::dsl::cb::*; + + let lookup_xor_row: Queriable = ctx.fixed("xor row(batch 4)"); + let lookup_xor_c: Queriable = ctx.fixed("xor value(batch 4)"); + + let constants_value = XOR_VALUE_BATCH4; + assert_eq!(lens, constants_value.len()); + ctx.pragma_num_steps(lens); + ctx.fixed_gen(move |ctx| { + for (i, &value) in constants_value.iter().enumerate().take(lens) { + ctx.assign( + i, + lookup_xor_row, + F::from((i / 27 * 512 + (i % 27) / 9 * 64 + (i % 9) / 3 * 8 + i % 3) as u64), + ); + ctx.assign(i, lookup_xor_c, F::from(value)); + } + }); + + ctx.new_table(table().add(lookup_xor_row).add(lookup_xor_c)) +} + +fn keccak_chi_table( + ctx: &mut CircuitContext, + lens: usize, +) -> LookupTable { + use chiquito::frontend::dsl::cb::*; + + let lookup_chi_row: Queriable = ctx.fixed("chi row"); + let lookup_chi_c: Queriable = ctx.fixed("chi value"); + + let constants_value = CHI_VALUE; + assert_eq!(lens, constants_value.len()); + ctx.pragma_num_steps(lens); + ctx.fixed_gen(move |ctx| { + for (i, &value) in constants_value.iter().enumerate().take(lens) { + ctx.assign( + i, + lookup_chi_row, + F::from(((i / 25) * 64 + (i % 25) / 5 * 8 + i % 5) as u64), + ); + ctx.assign(i, lookup_chi_c, F::from(value)); + } + }); + + ctx.new_table(table().add(lookup_chi_row).add(lookup_chi_c)) +} + +fn keccak_pack_table( + ctx: &mut CircuitContext, + _: usize, +) -> LookupTable { + use chiquito::frontend::dsl::cb::*; + + let lookup_pack_row: Queriable = ctx.fixed("pack row"); + let lookup_pack_c: Queriable = ctx.fixed("pack value"); + ctx.pragma_num_steps((SQUEEZE_SPLIT_NUM * SQUEEZE_SPLIT_NUM) as usize); + ctx.fixed_gen(move |ctx| { + for i in 0..SQUEEZE_SPLIT_NUM as usize { + let index = (i / 8) * 512 + (i % 8) / 4 * 64 + (i % 4) / 2 * 8 + i % 2; + for j in 0..SQUEEZE_SPLIT_NUM as usize { + let index_j = (j / 8) * 512 + (j % 8) / 4 * 64 + (j % 4) / 2 * 8 + j % 2; + ctx.assign( + i * SQUEEZE_SPLIT_NUM as usize + j, + lookup_pack_row, + F::from((index * 4096 + index_j) as u64), + ); + ctx.assign( + i * SQUEEZE_SPLIT_NUM as usize + j, + lookup_pack_c, + F::from((i * 16 + j) as u64), + ); + } + } + }); + ctx.new_table(table().add(lookup_pack_row).add(lookup_pack_c)) +} + +fn keccak_round_constants_table( + ctx: &mut CircuitContext, + lens: usize, +) -> LookupTable { + use chiquito::frontend::dsl::cb::*; + + let lookup_constant_row: Queriable = ctx.fixed("constant row"); + let lookup_constant_c: Queriable = ctx.fixed("constant value"); + + let constants_value = ROUND_CST; + ctx.pragma_num_steps(lens); + ctx.fixed_gen(move |ctx| { + for (i, &value) in constants_value.iter().enumerate().take(lens) { + ctx.assign(i, lookup_constant_row, F::from(i as u64)); + ctx.assign(i, lookup_constant_c, pack_u64::(value)); + } + }); + ctx.new_table(table().add(lookup_constant_row).add(lookup_constant_c)) +} + +struct PreValues { + s_vec: Vec, + absorb_rows: Vec, + round_value: F, + absorb_split_vec: Vec>, + absorb_split_input_vec: Vec>, + split_values: Vec>, + is_padding_vec: Vec>, + input_len: F, + data_rlc_vec: Vec>, + data_rlc: F, + input_acc: F, + padded: F, +} + +#[derive(Clone)] +struct SqueezeValues { + s_new_vec: Vec, + squeeze_split_vec: Vec>, + squeeze_split_output_vec: Vec>, + hash_rlc: F, +} + +#[derive(Clone)] +struct OneRoundValues { + round: F, + next_round: F, + round_cst: F, + input_len: F, + input_acc: F, + + s_vec: Vec, + s_new_vec: Vec, + + theta_split_vec: Vec>, + theta_split_xor_vec: Vec>, + theta_sum_split_vec: Vec>, + theta_sum_split_xor_vec: Vec>, + + rho_bit_0: Vec, + rho_bit_1: Vec, + + chi_split_value_vec: Vec>, + + final_sum_split_vec: Vec, + final_xor_split_vec: Vec, + + svalues: SqueezeValues, + data_rlc: F, + padded: F, +} + +fn eval_keccak_f_one_round + Eq + Hash>( + round: u64, + cst: u64, + s_vec: Vec, + input_len: u64, + data_rlc: F, + input_acc: F, + padded: F, +) -> OneRoundValues { + let mut s_new_vec = Vec::new(); + let mut theta_split_vec = Vec::new(); + let mut theta_split_xor_vec = Vec::new(); + let mut theta_sum_split_xor_value_vec = Vec::new(); + let mut theta_sum_split_xor_move_value_vec = Vec::new(); + let mut theta_sum_split_vec = Vec::new(); + let mut theta_sum_split_xor_vec = Vec::new(); + let mut rho_pi_s_new_vec = vec![F::ZERO; PART_SIZE_SQURE as usize]; + let mut rho_bit_0 = vec![F::ZERO; 15]; + let mut rho_bit_1 = vec![F::ZERO; 15]; + let mut chi_sum_value_vec = Vec::new(); + let mut chi_sum_split_value_vec = Vec::new(); + let mut chi_split_value_vec = Vec::new(); + let mut final_sum_split_vec = Vec::new(); + let mut final_xor_split_vec = Vec::new(); + + let mut t_vec = vec![0; PART_SIZE_SQURE as usize]; + { + let mut i: usize = 1; + let mut j: usize = 0; + for t in 0..PART_SIZE_SQURE as usize { + if t == 0 { + i = 0; + j = 0 + } else if t == 1 { + i = 1; + j = 0; + } else { + let m = j; + j = (2 * i + 3 * j) % PART_SIZE as usize; + i = m; + } + t_vec[i * PART_SIZE as usize + j] = t; + } + } + + for i in 0..PART_SIZE as usize { + let sum = s_vec[i * PART_SIZE as usize] + + s_vec[i * PART_SIZE as usize + 1] + + s_vec[i * PART_SIZE as usize + 2] + + s_vec[i * PART_SIZE as usize + 3] + + s_vec[i * PART_SIZE as usize + 4]; + let sum_bits = convert_field_to_vec_bits(sum); + + let xor: F = field_xor( + field_xor( + field_xor( + field_xor( + s_vec[i * PART_SIZE as usize], + s_vec[i * PART_SIZE as usize + 1], + ), + s_vec[i * PART_SIZE as usize + 2], + ), + s_vec[i * PART_SIZE as usize + 3], + ), + s_vec[i * PART_SIZE as usize + 4], + ); + let xor_bits = convert_field_to_vec_bits(xor); + let mut xor_bits_move = xor_bits.clone(); + xor_bits_move.rotate_right(1); + let xor_rot: F = convert_bits_to_f(&xor_bits_move); + + let mut sum_split = Vec::new(); + let mut sum_split_xor = Vec::new(); + for k in 0..sum_bits.len() / 2 { + if k == sum_bits.len() / 2 - 1 { + sum_split.push(F::from_u128(sum_bits[2 * k] as u128)); + sum_split.push(F::from_u128(sum_bits[2 * k + 1] as u128)); + sum_split_xor.push(F::from_u128(xor_bits[2 * k] as u128)); + sum_split_xor.push(F::from_u128(xor_bits[2 * k + 1] as u128)); + } else { + sum_split.push( + F::from_u128(sum_bits[2 * k] as u128) + + F::from_u128(sum_bits[2 * k + 1] as u128) * F::from_u128(8), + ); + sum_split_xor.push( + F::from_u128(xor_bits[2 * k] as u128) + + F::from_u128(xor_bits[2 * k + 1] as u128) * F::from_u128(8), + ); + } + } + + theta_split_vec.push(sum_split); + theta_split_xor_vec.push(sum_split_xor); + theta_sum_split_xor_value_vec.push(xor); + theta_sum_split_xor_move_value_vec.push(xor_rot); + } + + let mut rho_index = 0; + for i in 0..PART_SIZE as usize { + let xor = theta_sum_split_xor_value_vec[(i + PART_SIZE as usize - 1) % PART_SIZE as usize]; + let xor_rot = theta_sum_split_xor_move_value_vec[(i + 1) % PART_SIZE as usize]; + for j in 0..PART_SIZE as usize { + let v = ((t_vec[i * PART_SIZE as usize + j] + 1) * t_vec[i * PART_SIZE as usize + j] + / 2) + % NUM_BITS_PER_WORD as usize; + let st = s_vec[i * PART_SIZE as usize + j] + xor + xor_rot; + let st_xor = field_xor(field_xor(s_vec[i * PART_SIZE as usize + j], xor), xor_rot); + let mut st_split = Vec::new(); + let mut st_split_xor = Vec::new(); + let mut st_bit_vec = convert_field_to_vec_bits(st); + let mut st_bit_xor_vec = convert_field_to_vec_bits(st_xor); + + // rho + // a[x][y][z] = a[x][y][z-(t+1)(t+2)/2] + if v % 3 == 1 { + rho_bit_0[rho_index] = + F::from(st_bit_vec[1] as u64) * F::from_u128(8) + F::from(st_bit_vec[0] as u64); + rho_bit_1[rho_index] = F::from(st_bit_vec[NUM_BITS_PER_WORD as usize - 1] as u64); + rho_index += 1 + } else if v % 3 == 2 { + rho_bit_0[rho_index] = F::from(st_bit_vec[0] as u64); + rho_bit_1[rho_index] = F::from(st_bit_vec[NUM_BITS_PER_WORD as usize - 1] as u64) + * F::from_u128(8) + + F::from(st_bit_vec[NUM_BITS_PER_WORD as usize - 2] as u64); + rho_index += 1 + } + + st_bit_vec.rotate_right(v); + st_bit_xor_vec.rotate_right(v); + + for i in 0..st_bit_vec.len() / 3 { + st_split.push( + F::from_u128(st_bit_vec[3 * i] as u128) + + F::from_u128(st_bit_vec[3 * i + 1] as u128) * F::from_u128(8) + + F::from_u128(st_bit_vec[3 * i + 2] as u128) * F::from_u128(64), + ); + st_split_xor.push( + F::from_u128(st_bit_xor_vec[3 * i] as u128) + + F::from_u128(st_bit_xor_vec[3 * i + 1] as u128) * F::from_u128(8) + + F::from_u128(st_bit_xor_vec[3 * i + 2] as u128) * F::from_u128(64), + ); + } + st_split.push(F::from_u128( + st_bit_vec[NUM_BITS_PER_WORD as usize - 1] as u128, + )); + st_split_xor.push(F::from_u128( + st_bit_xor_vec[NUM_BITS_PER_WORD as usize - 1] as u128, + )); + + theta_sum_split_vec.push(st_split); + theta_sum_split_xor_vec.push(st_split_xor); + + // pi + // a[y][2x + 3y] = a[x][y] + rho_pi_s_new_vec[j * PART_SIZE as usize + ((2 * i + 3 * j) % PART_SIZE as usize)] = + convert_bits_to_f(&st_bit_xor_vec); + } + } + + // chi + // a[x] = a[x] ^ (~a[x+1] & a[x+2]) + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + let a_vec = convert_field_to_vec_bits(rho_pi_s_new_vec[i * PART_SIZE as usize + j]); + let b_vec = + convert_field_to_vec_bits(rho_pi_s_new_vec[((i + 1) % 5) * PART_SIZE as usize + j]); + let c_vec = + convert_field_to_vec_bits(rho_pi_s_new_vec[((i + 2) % 5) * PART_SIZE as usize + j]); + let sum_vec: Vec = a_vec + .iter() + .zip(b_vec.iter().zip(c_vec.iter())) + .map(|(&a, (&b, &c))| 3 + b - 2 * a - c) + .collect(); + let sum: F = convert_bits_to_f(&sum_vec); + + let split_chi_value: Vec = sum_vec + .iter() + .map(|&v| if v == 1 || v == 2 { 1 } else { 0 }) + .collect(); + let sum_chi = convert_bits_to_f(&split_chi_value); + + let sum_split_vec: Vec = (0..NUM_PER_WORD_BATCH3 as usize) + .map(|i| { + if i == NUM_PER_WORD_BATCH3 as usize - 1 { + F::from_u128(sum_vec[3 * i] as u128) + } else { + F::from_u128( + sum_vec[3 * i] as u128 + + sum_vec[3 * i + 1] as u128 * 8 + + sum_vec[3 * i + 2] as u128 * 64, + ) + } + }) + .collect(); + let chi_split_vec: Vec = (0..NUM_PER_WORD_BATCH3 as usize) + .map(|i| { + if i == NUM_PER_WORD_BATCH3 as usize - 1 { + F::from_u128(split_chi_value[3 * i] as u128) + } else { + F::from_u128( + split_chi_value[3 * i] as u128 + + split_chi_value[3 * i + 1] as u128 * 8 + + split_chi_value[3 * i + 2] as u128 * 64, + ) + } + }) + .collect(); + + chi_sum_value_vec.push(sum); + s_new_vec.push(sum_chi); + chi_sum_split_value_vec.push(sum_split_vec); + chi_split_value_vec.push(chi_split_vec); + } + } + + let s_iota_vec = convert_field_to_vec_bits(s_new_vec[0]); + let cst_vec = convert_field_to_vec_bits(pack_u64::(cst)); + let split_xor_vec: Vec = s_iota_vec + .iter() + .zip(cst_vec.iter()) + .map(|(v1, v2)| v1 ^ v2) + .collect(); + let xor_rows: Vec<(F, F)> = s_iota_vec + .iter() + .zip(cst_vec.iter()) + .map(|(v1, v2)| { + ( + F::from_u128((v1 + v2) as u128), + F::from_u128((v1 ^ v2) as u128), + ) + }) + .collect(); + + for i in 0..NUM_PER_WORD_BATCH4 as usize { + final_sum_split_vec.push( + xor_rows[4 * i].0 + + xor_rows[4 * i + 1].0 * F::from_u128(8) + + xor_rows[4 * i + 2].0 * F::from_u128(64) + + xor_rows[4 * i + 3].0 * F::from_u128(512), + ); + final_xor_split_vec.push( + xor_rows[4 * i].1 + + xor_rows[4 * i + 1].1 * F::from_u128(8) + + xor_rows[4 * i + 2].1 * F::from_u128(64) + + xor_rows[4 * i + 3].1 * F::from_u128(512), + ); + } + + s_new_vec[0] = convert_bits_to_f(&split_xor_vec); + + let svalues = SqueezeValues { + s_new_vec: Vec::new(), + squeeze_split_vec: Vec::new(), + squeeze_split_output_vec: Vec::new(), + hash_rlc: F::ZERO, + }; + + let next_round = if round < NUM_ROUNDS - 1 { round + 1 } else { 0 }; + + OneRoundValues { + round: F::from(round), + round_cst: pack_u64::(cst), + input_len: F::from(input_len), + next_round: F::from(next_round), + + s_vec, + s_new_vec, + + theta_split_vec, + theta_split_xor_vec, + theta_sum_split_vec, + theta_sum_split_xor_vec, + + rho_bit_0, + rho_bit_1, + + chi_split_value_vec, + + final_sum_split_vec, + final_xor_split_vec, + + svalues, + data_rlc, + input_acc, + padded, + } +} + +fn keccak_circuit + Eq + Hash>( + ctx: &mut CircuitContext, + param: CircuitParams, +) { + use chiquito::frontend::dsl::cb::*; + + let s_vec: Vec> = (0..PART_SIZE_SQURE) + .map(|i| ctx.forward(&format!("s[{}][{}]", i / PART_SIZE, i % PART_SIZE))) + .collect(); + + let round = ctx.forward("round"); + let data_rlc = ctx.forward("data_rlc"); + + let input_len = ctx.forward("input_len"); + let input_acc = ctx.forward("input_acc"); + + let padded = ctx.forward("padded"); + + let keccak_first_step = ctx.step_type_def("keccak first step", |ctx| { + let s_vec = s_vec.clone(); + let setup_s_vec = s_vec.clone(); + + let absorb_vec: Vec> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| ctx.internal(&format!("absorb_{}", i))) + .collect(); + let setup_absorb_vec = absorb_vec.clone(); + + let absorb_split_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("absorb_split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_absorb_split_vec = absorb_split_vec.clone(); + + let absorb_split_input_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("absorb_split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_absorb_split_input_vec = absorb_split_input_vec.clone(); + + let is_padding_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("is_padding_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_is_padding_vec = is_padding_vec.clone(); + + let data_rlc_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("is_padding_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_data_rlc_vec = data_rlc_vec.clone(); + + ctx.setup(move |ctx| { + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + ctx.constr(eq(setup_s_vec[i * PART_SIZE as usize + j], 0)); + if j * PART_SIZE as usize + i < NUM_WORDS_TO_ABSORB as usize { + // xor + // 000 xor 000/001 -> 000 + 000/001 + ctx.transition(eq( + setup_s_vec[i * PART_SIZE as usize + j] + + setup_absorb_vec[j * PART_SIZE as usize + i], + setup_s_vec[i * PART_SIZE as usize + j].next(), + )); + + let mut tmp_absorb_split_sum_vec = setup_absorb_split_vec + [j * PART_SIZE as usize + i][SQUEEZE_SPLIT_NUM as usize / 2 - 1] + * 1; + for k in 1..SQUEEZE_SPLIT_NUM as usize / 2 { + tmp_absorb_split_sum_vec = tmp_absorb_split_sum_vec * 4096 * 4096 + + setup_absorb_split_vec[j * PART_SIZE as usize + i] + [SQUEEZE_SPLIT_NUM as usize / 2 - k - 1]; + } + ctx.constr(eq( + setup_absorb_vec[j * PART_SIZE as usize + i], + tmp_absorb_split_sum_vec, + )); + + for k in 0..SQUEEZE_SPLIT_NUM as usize / 2 { + ctx.add_lookup( + param + .pack_table + .apply(setup_absorb_split_vec[j * PART_SIZE as usize + i][k]) + .apply( + setup_absorb_split_input_vec[j * PART_SIZE as usize + i][k], + ), + ); + ctx.constr(eq( + (setup_is_padding_vec[j * PART_SIZE as usize + i][k] - 1) + * setup_is_padding_vec[j * PART_SIZE as usize + i][k], + 0, + )); + } + } else { + ctx.transition(eq( + setup_s_vec[i * PART_SIZE as usize + j], + setup_s_vec[i * PART_SIZE as usize + j].next(), + )); + } + } + } + ctx.constr(eq(data_rlc, 0)); + ctx.transition(eq( + setup_data_rlc_vec[NUM_WORDS_TO_ABSORB as usize - 1] + [SQUEEZE_SPLIT_NUM as usize / 2 - 1], + data_rlc.next(), + )); + let mut acc_value = 0.expr() * 1; + for i in 0..NUM_WORDS_TO_ABSORB as usize { + if i == 0 { + // data_rlc_vec[0][0] = 0 * 256 + absorb_split_input_vec[0][0]; + ctx.constr(eq( + setup_data_rlc_vec[i][0], + (data_rlc * 256 + setup_absorb_split_input_vec[i][0]) + * (1.expr() - setup_is_padding_vec[i][0]) + + data_rlc * setup_is_padding_vec[i][0], + )); + } else { + // data_rlc_vec[0][0] = 0 * 256 + absorb_split_input_vec[0][0]; + ctx.constr(eq( + setup_data_rlc_vec[i][0], + (setup_data_rlc_vec[i - 1][SQUEEZE_SPLIT_NUM as usize / 2 - 1] * 256 + + setup_absorb_split_input_vec[i][0]) + * (setup_is_padding_vec[i][0] - 1).neg() + + setup_data_rlc_vec[i - 1][SQUEEZE_SPLIT_NUM as usize / 2 - 1] + * setup_is_padding_vec[i][0], + )); + } + + for k in 1..SQUEEZE_SPLIT_NUM as usize / 2 { + ctx.constr(eq( + setup_data_rlc_vec[i][k], + (setup_data_rlc_vec[i][k - 1] * 256 + setup_absorb_split_input_vec[i][k]) + * (setup_is_padding_vec[i][k] - 1).neg() + + setup_data_rlc_vec[i][k - 1] * setup_is_padding_vec[i][k], + )); + } + acc_value = acc_value + (1.expr() - setup_is_padding_vec[i][0]); + if i == 0 { + ctx.constr(eq(setup_is_padding_vec[i][0], 0)); + } else { + ctx.constr(eq( + (setup_is_padding_vec[i][0] - setup_is_padding_vec[i - 1][7]) + * ((setup_is_padding_vec[i][0] - setup_is_padding_vec[i - 1][7]) - 1), + 0, + )); + ctx.constr(eq( + (setup_is_padding_vec[i][0] - setup_is_padding_vec[i - 1][7]) + * (setup_absorb_split_vec[i][0] - 1), + 0, + )); + } + for k in 1..8 { + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * ((setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) - 1), + 0, + )); + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * (setup_absorb_split_vec[i][k] - 1), + 0, + )); + // the last one + if k == 7 && i == NUM_WORDS_TO_ABSORB as usize - 1 { + // the padding length is equal than 1 byte + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * (setup_absorb_split_vec[i][k] - 2097153), + 0, + )); + // the padding length is bigger than 1 byte + ctx.constr(eq( + setup_is_padding_vec[i][k - 1] + * (setup_absorb_split_vec[i][k] - 2097152), + 0, + )); + } else { + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * (setup_absorb_split_vec[i][k] - 1), + 0, + )); + // the first padding byte = 1, other = 0 + ctx.constr(eq( + setup_is_padding_vec[i][k] + * (setup_is_padding_vec[i][k] + - setup_is_padding_vec[i][k - 1] + - setup_absorb_split_vec[i][k]), + 0, + )); + } + + acc_value = acc_value + (1.expr() - setup_is_padding_vec[i][k]); + } + } + ctx.constr(eq( + (input_len - input_acc - acc_value.clone()) + * setup_is_padding_vec[NUM_WORDS_TO_ABSORB as usize - 1][7], + 0, + )); + ctx.transition(eq(input_acc + acc_value, input_acc.next())); + + ctx.constr(eq(round, 0)); + ctx.transition(eq(round, round.next())); + ctx.transition(eq(input_len, input_len.next())); + ctx.constr(eq(padded, 0)); + ctx.transition(eq( + setup_is_padding_vec[NUM_WORDS_TO_ABSORB as usize - 1][7], + padded.next(), + )); + }); + + ctx.wg::, _>(move |ctx, values| { + for (q, v) in absorb_vec.iter().zip(values.absorb_rows.iter()) { + ctx.assign(*q, *v) + } + + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + ctx.assign(s_vec[i * PART_SIZE as usize + j], F::ZERO); + } + } + for (q_vec, v_vec) in absorb_split_vec.iter().zip(values.absorb_split_vec.iter()) { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q_vec, v_vec) in absorb_split_input_vec + .iter() + .zip(values.absorb_split_input_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + + for (q_vec, v_vec) in is_padding_vec.iter().zip(values.is_padding_vec.iter()) { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + + for (q_vec, v_vec) in data_rlc_vec.iter().zip(values.data_rlc_vec.iter()) { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + ctx.assign(round, values.round_value); + ctx.assign(input_len, values.input_len); + ctx.assign(data_rlc, values.data_rlc); + ctx.assign(input_acc, values.input_acc); + ctx.assign(padded, values.padded); + }) + }); + + let keccak_pre_step = ctx.step_type_def("keccak pre step", |ctx| { + let s_vec = s_vec.clone(); + let setup_s_vec = s_vec.clone(); + + let absorb_vec: Vec> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| ctx.internal(&format!("absorb_{}", i))) + .collect(); + let setup_absorb_vec = absorb_vec.clone(); + + let absorb_split_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("absorb_split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_absorb_split_vec = absorb_split_vec.clone(); + + let absorb_split_input_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("absorb_split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_absorb_split_input_vec = absorb_split_input_vec.clone(); + + let sum_split_value_vec: Vec> = (0..PART_SIZE_SQURE) + .map(|i| ctx.internal(&format!("sum_split_value_{}", i))) + .collect(); + let setup_sum_split_value_vec = sum_split_value_vec.clone(); + + let split_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..NUM_PER_WORD_BATCH4) + .map(|j| ctx.internal(&format!("split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_split_vec = split_vec.clone(); + + let split_xor_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..NUM_PER_WORD_BATCH4) + .map(|j| ctx.internal(&format!("split_xor_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_split_xor_vec = split_xor_vec.clone(); + + let is_padding_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("is_padding_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_is_padding_vec = is_padding_vec.clone(); + + let data_rlc_vec: Vec>> = (0..NUM_WORDS_TO_ABSORB) + .map(|i| { + (0..8) + .map(|j| ctx.internal(&format!("data_rlc_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_data_rlc_vec = data_rlc_vec.clone(); + + ctx.setup(move |ctx| { + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + if j * PART_SIZE as usize + i < NUM_WORDS_TO_ABSORB as usize { + // xor + ctx.constr(eq( + setup_s_vec[i * PART_SIZE as usize + j] + + setup_absorb_vec[j * PART_SIZE as usize + i], + setup_sum_split_value_vec[i * PART_SIZE as usize + j], + )); + + let mut tmp_absorb_split_sum_vec = setup_absorb_split_vec + [j * PART_SIZE as usize + i][SQUEEZE_SPLIT_NUM as usize / 2 - 1] + * 1; + for k in 1..SQUEEZE_SPLIT_NUM as usize / 2 { + tmp_absorb_split_sum_vec = tmp_absorb_split_sum_vec * 4096 * 4096 + + setup_absorb_split_vec[j * PART_SIZE as usize + i] + [SQUEEZE_SPLIT_NUM as usize / 2 - k - 1]; + } + ctx.constr(eq( + setup_absorb_vec[j * PART_SIZE as usize + i], + tmp_absorb_split_sum_vec, + )); + for k in 0..SQUEEZE_SPLIT_NUM as usize / 2 { + ctx.add_lookup( + param + .pack_table + .apply(setup_absorb_split_vec[j * PART_SIZE as usize + i][k]) + .apply( + setup_absorb_split_input_vec[j * PART_SIZE as usize + i][k], + ), + ); + } + + for k in 0..NUM_PER_WORD_BATCH4 as usize { + ctx.add_lookup( + param + .xor_table_batch4 + .apply(setup_split_vec[j * PART_SIZE as usize + i][k]) + .apply(setup_split_xor_vec[j * PART_SIZE as usize + i][k]), + ); + } + } else { + ctx.transition(eq( + setup_s_vec[i * PART_SIZE as usize + j], + setup_s_vec[i * PART_SIZE as usize + j].next(), + )); + } + } + } + + ctx.transition(eq( + setup_data_rlc_vec[NUM_WORDS_TO_ABSORB as usize - 1] + [SQUEEZE_SPLIT_NUM as usize / 2 - 1], + data_rlc.next(), + )); + + let mut acc_value = 0.expr() * 1; + for i in 0..NUM_WORDS_TO_ABSORB as usize { + if i == 0 { + // data_rlc_vec[0][0] = 0 * 256 + absorb_split_input_vec[0][0]; + ctx.constr(eq( + setup_data_rlc_vec[i][0], + (data_rlc * 256 + setup_absorb_split_input_vec[i][0]) + * (setup_is_padding_vec[i][0] - 1).neg() + + data_rlc * setup_is_padding_vec[i][0], + )); + } else { + // data_rlc_vec[0][0] = 0 * 256 + absorb_split_input_vec[0][0]; + ctx.constr(eq( + setup_data_rlc_vec[i][0], + (setup_data_rlc_vec[i - 1][SQUEEZE_SPLIT_NUM as usize / 2 - 1] * 256 + + setup_absorb_split_input_vec[i][0]) + * (setup_is_padding_vec[i][0] - 1).neg() + + setup_data_rlc_vec[i - 1][SQUEEZE_SPLIT_NUM as usize / 2 - 1] + * setup_is_padding_vec[i][0], + )); + } + for k in 1..SQUEEZE_SPLIT_NUM as usize / 2 { + ctx.constr(eq( + setup_data_rlc_vec[i][k], + (setup_data_rlc_vec[i][k - 1] * 256 + setup_absorb_split_input_vec[i][k]) + * (setup_is_padding_vec[i][k] - 1).neg() + + setup_data_rlc_vec[i][k - 1] * setup_is_padding_vec[i][k], + )); + } + + acc_value = acc_value + (1.expr() - setup_is_padding_vec[i][0]); + if i == 0 { + ctx.constr(eq( + setup_is_padding_vec[i][0] * (setup_is_padding_vec[i][0] - 1), + 0, + )); + } else { + ctx.constr(eq( + (setup_is_padding_vec[i][0] - setup_is_padding_vec[i - 1][7]) + * ((setup_is_padding_vec[i][0] - setup_is_padding_vec[i - 1][7]) - 1), + 0, + )); + ctx.constr(eq( + (setup_is_padding_vec[i][0] - setup_is_padding_vec[i - 1][7]) + * (setup_absorb_split_vec[i][0] - 1), + 0, + )); + } + for k in 1..8 { + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * ((setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) - 1), + 0, + )); + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * (setup_absorb_split_vec[i][k] - 1), + 0, + )); + + if k == 7 && i == NUM_WORDS_TO_ABSORB as usize - 1 { + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * (setup_absorb_split_vec[i][k] - 2097153), + 0, + )); + ctx.constr(eq( + setup_is_padding_vec[i][k - 1] + * (setup_absorb_split_vec[i][k] - 2097152), + 0, + )); + } else { + ctx.constr(eq( + (setup_is_padding_vec[i][k] - setup_is_padding_vec[i][k - 1]) + * (setup_absorb_split_vec[i][k] - 1), + 0, + )); + ctx.constr(eq( + setup_is_padding_vec[i][k] + * (setup_is_padding_vec[i][k] + - setup_is_padding_vec[i][k - 1] + - setup_absorb_split_vec[i][k]), + 0, + )); + } + acc_value = acc_value + (1.expr() - setup_is_padding_vec[i][k]); + } + } + + for s in 0..NUM_WORDS_TO_ABSORB as usize { + let mut sum_split_vec = setup_split_vec[s][NUM_PER_WORD_BATCH4 as usize - 1] * 1; + let mut sum_split_xor_vec = + setup_split_xor_vec[s][NUM_PER_WORD_BATCH4 as usize - 1] * 1; + for (&value, &xor_value) in setup_split_vec[s] + .iter() + .rev() + .zip(setup_split_xor_vec[s].iter().rev()) + .skip(1) + { + sum_split_vec = sum_split_vec * 64 * 64 + value; + sum_split_xor_vec = sum_split_xor_vec * 64 * 64 + xor_value; + } + ctx.constr(eq( + sum_split_vec, + setup_sum_split_value_vec + [(s % PART_SIZE as usize) * PART_SIZE as usize + s / PART_SIZE as usize], + )); + ctx.transition(eq( + sum_split_xor_vec, + setup_s_vec + [(s % PART_SIZE as usize) * PART_SIZE as usize + s / PART_SIZE as usize] + .next(), + )); + } + + ctx.constr(eq( + (input_len - input_acc - acc_value.clone()) + * setup_is_padding_vec[NUM_WORDS_TO_ABSORB as usize - 1][7], + 0, + )); + ctx.transition(eq(input_acc + acc_value, input_acc.next())); + + ctx.transition(eq(round, round.next())); + ctx.transition(eq(input_len, input_len.next())); + + ctx.constr(eq(padded, 0)); + ctx.transition(eq( + setup_is_padding_vec[NUM_WORDS_TO_ABSORB as usize - 1][7], + padded.next(), + )); + }); + + ctx.wg::, _>(move |ctx, values| { + ctx.assign(round, F::ZERO); + for (q, v) in absorb_vec.iter().zip(values.absorb_rows.iter()) { + ctx.assign(*q, *v) + } + + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + ctx.assign(s_vec[i * PART_SIZE as usize + j], F::ZERO); + } + } + for (q_vec, v_vec) in absorb_split_vec.iter().zip(values.absorb_split_vec.iter()) { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q_vec, v_vec) in absorb_split_input_vec + .iter() + .zip(values.absorb_split_input_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + + for (q_vec, v_vec) in is_padding_vec.iter().zip(values.is_padding_vec.iter()) { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + + for (q_vec, v_vec) in data_rlc_vec.iter().zip(values.data_rlc_vec.iter()) { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + if j * PART_SIZE as usize + i < NUM_WORDS_TO_ABSORB as usize { + ctx.assign( + sum_split_value_vec[i * PART_SIZE as usize + j], + values.s_vec[i * PART_SIZE as usize + j] + + values.absorb_rows[j * PART_SIZE as usize + i], + ); + ctx.assign( + absorb_vec[j * PART_SIZE as usize + i], + values.absorb_rows[j * PART_SIZE as usize + i], + ); + } else { + ctx.assign( + sum_split_value_vec[i * PART_SIZE as usize + j], + values.s_vec[i * PART_SIZE as usize + j], + ); + } + ctx.assign( + s_vec[i * PART_SIZE as usize + j], + values.s_vec[i * PART_SIZE as usize + j], + ); + } + } + + for i in 0..NUM_WORDS_TO_ABSORB as usize { + for j in 0..NUM_PER_WORD_BATCH4 as usize { + ctx.assign(split_vec[i][j], values.split_values[i][j].0); + ctx.assign(split_xor_vec[i][j], values.split_values[i][j].1); + } + } + ctx.assign(input_len, values.input_len); + ctx.assign(data_rlc, values.data_rlc); + ctx.assign(input_acc, values.input_acc); + ctx.assign(padded, values.padded); + }) + }); + + let keccak_one_round_step_vec: Vec, _>> = (0..2) + .map(|last| { + ctx.step_type_def("keccak one round", |ctx| { + let s_vec = s_vec.clone(); + let setup_s_vec = s_vec.clone(); + + let theta_split_vec: Vec>> = (0..PART_SIZE) + .map(|i| { + (0..NUM_PER_WORD + 1) + .map(|j| ctx.internal(&format!("theta_split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_theta_split_vec = theta_split_vec.clone(); + + let theta_split_xor_vec: Vec>> = (0..PART_SIZE) + .map(|i| { + (0..NUM_PER_WORD + 1) + .map(|j| ctx.internal(&format!("theta_split_xor_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_theta_split_xor_vec = theta_split_xor_vec.clone(); + + let theta_sum_split_vec: Vec>> = (0..PART_SIZE_SQURE) + .map(|i| { + (0..NUM_PER_WORD_BATCH3) + .map(|j| ctx.internal(&format!("theta_sum_split_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_theta_sum_split_vec = theta_sum_split_vec.clone(); + + let theta_sum_split_xor_vec: Vec>> = (0..PART_SIZE_SQURE) + .map(|i| { + (0..NUM_PER_WORD_BATCH3) + .map(|j| ctx.internal(&format!("theta_sum_split_xor_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_theta_sum_split_xor_vec = theta_sum_split_xor_vec.clone(); + + let rho_bit_0: Vec> = (0..15) + .map(|i| ctx.internal(&format!("rho_bit0_{}", i))) + .collect(); + let setup_rho_bit_0 = rho_bit_0.clone(); + + let rho_bit_1: Vec> = (0..15) + .map(|i| ctx.internal(&format!("rho_bit1_{}", i))) + .collect(); + let setup_rho_bit_1 = rho_bit_1.clone(); + + let chi_split_value_vec: Vec>> = (0..PART_SIZE_SQURE) + .map(|i| { + (0..NUM_PER_WORD_BATCH3) + .map(|j| ctx.internal(&format!("chi_split_value_{}_{}", i, j))) + .collect() + }) + .collect(); + let setup_chi_split_value_vec: Vec>> = chi_split_value_vec.clone(); + + let final_xor_split_vec: Vec> = (0..NUM_PER_WORD_BATCH4) + .map(|i| ctx.internal(&format!("final_xor_split_{}", i))) + .collect(); + let setup_final_xor_split_vec = final_xor_split_vec.clone(); + + let final_sum_split_vec: Vec> = (0..NUM_PER_WORD_BATCH4) + .map(|i| ctx.internal(&format!("final_sum_split_{}", i))) + .collect(); + let setup_final_sum_split_vec = final_sum_split_vec.clone(); + let round_cst: Queriable = ctx.internal("round constant"); + + let mut hash_rlc = data_rlc; + let mut next_round = round; + if last == 0 { + next_round = ctx.internal("next_round"); + } else { + hash_rlc = ctx.internal("hash_rlc"); + } + + let mut squeeze_split_vec: Vec>> = Vec::new(); + if last == 1 { + squeeze_split_vec = (0..SQUEEZE_VECTOR_NUM) + .map(|i| { + (0..SQUEEZE_SPLIT_NUM / 2) + .map(|j| ctx.internal(&format!("squeeze_split_vec_{}_{}", i, j))) + .collect() + }) + .collect(); + } + let setup_squeeze_split_vec = squeeze_split_vec.clone(); + + let mut squeeze_split_output_vec: Vec>> = Vec::new(); + if last == 1 { + squeeze_split_output_vec = (0..SQUEEZE_VECTOR_NUM) + .map(|i| { + (0..SQUEEZE_SPLIT_NUM / 2) + .map(|j| { + ctx.internal(&format!("squeeze_split_output_vec_{}_{}", i, j)) + }) + .collect() + }) + .collect(); + } + let setup_squeeze_split_output_vec = squeeze_split_output_vec.clone(); + + let mut s_new_vec: Vec> = Vec::new(); + if last == 1 { + s_new_vec = (0..PART_SIZE_SQURE) + .map(|i| { + ctx.internal(&format!("s_new[{}][{}]", i / PART_SIZE, i % PART_SIZE)) + }) + .collect(); + } + let setup_s_new_vec = s_new_vec.clone(); + + ctx.setup(move |ctx| { + let mut t_vec = vec![0; PART_SIZE_SQURE as usize]; + { + let mut i: usize = 1; + let mut j: usize = 0; + for t in 0..PART_SIZE_SQURE { + if t == 0 { + i = 0; + j = 0 + } else if t == 1 { + i = 1; + j = 0; + } else { + let m = j; + j = (2 * i + 3 * j) % PART_SIZE as usize; + i = m; + } + t_vec[i * PART_SIZE as usize + j] = t; + } + } + + // Theta + let mut tmp_theta_sum_split_xor_vec = Vec::new(); + let mut tmp_theta_sum_move_split_xor_vec = Vec::new(); + for s in 0..PART_SIZE as usize { + // 1. \sum_y' a[x][y'][z] + // 2. xor(sum) + let mut sum_split_vec = setup_theta_split_vec[s][NUM_PER_WORD as usize] * 8 + + setup_theta_split_vec[s][NUM_PER_WORD as usize - 1]; + let mut sum_split_xor_vec = + setup_theta_split_xor_vec[s][NUM_PER_WORD as usize] * 8 + + setup_theta_split_xor_vec[s][NUM_PER_WORD as usize - 1]; + let mut sum_split_xor_move_value_vec = + setup_theta_split_xor_vec[s][NUM_PER_WORD as usize - 1] * 1; + for k in 1..NUM_PER_WORD as usize { + sum_split_vec = sum_split_vec * 64 + + setup_theta_split_vec[s][NUM_PER_WORD as usize - k - 1]; + sum_split_xor_vec = sum_split_xor_vec * 64 + + setup_theta_split_xor_vec[s][NUM_PER_WORD as usize - k - 1]; + sum_split_xor_move_value_vec = sum_split_xor_move_value_vec * 64 + + setup_theta_split_xor_vec[s][NUM_PER_WORD as usize - k - 1]; + } + sum_split_xor_move_value_vec = sum_split_xor_move_value_vec * 8 + + setup_theta_split_xor_vec[s][NUM_PER_WORD as usize]; + + for k in 0..NUM_PER_WORD as usize { + ctx.add_lookup( + param + .xor_table + .apply(setup_theta_split_vec[s][k]) + .apply(setup_theta_split_xor_vec[s][k]), + ); + } + + ctx.constr(eq( + setup_s_vec[s * PART_SIZE as usize] + + setup_s_vec[s * PART_SIZE as usize + 1] + + setup_s_vec[s * PART_SIZE as usize + 2] + + setup_s_vec[s * PART_SIZE as usize + 3] + + setup_s_vec[s * PART_SIZE as usize + 4], + sum_split_vec, + )); + + tmp_theta_sum_split_xor_vec.push(sum_split_xor_vec); + tmp_theta_sum_move_split_xor_vec.push(sum_split_xor_move_value_vec); + } + + let mut rho_index = 0; + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + // Theta + // 3. a[x][y][z] = a[x][y][z] + xor(\sum_y' a[x-1][y'][z]) + xor(\sum + // a[x+1][y'][z-1]) 4. a'[x][y][z'+(t+1)(t+2)/2] = + // xor(a[x][y][z'+(t+1)(t+2)/2]) rho + // a[x][y][z'] = a[x][y][z'] + let v = ((t_vec[i * PART_SIZE as usize + j] + 1) + * t_vec[i * PART_SIZE as usize + j] + / 2) + % NUM_BITS_PER_WORD; + + for k in 0..NUM_PER_WORD_BATCH3 as usize { + ctx.add_lookup( + param + .xor_table_batch3 + .apply( + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [k], + ) + .apply( + setup_theta_sum_split_xor_vec + [i * PART_SIZE as usize + j][k], + ), + ); + } + + let mut tmp_theta_sum_split; + if v % 3 == 0 { + let st = (v / 3) as usize; + if st != 0 { + tmp_theta_sum_split = setup_theta_sum_split_vec + [i * PART_SIZE as usize + j][st - 1] + * 1; + for k in 1..st { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [st - k - 1]; + } + tmp_theta_sum_split = tmp_theta_sum_split * 8 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - 1]; + for k in 1..NUM_PER_WORD_BATCH3 as usize - st { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - k - 1]; + } + } else { + tmp_theta_sum_split = setup_theta_sum_split_vec + [i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - 1] + * 1; + for k in 1..NUM_PER_WORD_BATCH3 as usize { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - k - 1]; + } + } + } else if v % 3 == 1 { + let st = ((v - 1) / 3) as usize; + tmp_theta_sum_split = setup_rho_bit_1[rho_index] * 1; + for k in 0..st { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [st - k - 1]; + } + for k in 0..NUM_PER_WORD_BATCH3 as usize - st - 1 { + if k == 0 { + tmp_theta_sum_split = tmp_theta_sum_split * 8 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - 1]; + } else { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - k - 1]; + } + } + tmp_theta_sum_split = + tmp_theta_sum_split * 64 + setup_rho_bit_0[rho_index]; + ctx.constr(eq( + setup_rho_bit_0[rho_index] * 8 + setup_rho_bit_1[rho_index], + setup_theta_sum_split_vec[i * PART_SIZE as usize + j][st], + )); + rho_index += 1; + } else { + let st = ((v - 2) / 3) as usize; + tmp_theta_sum_split = setup_rho_bit_1[rho_index] * 1; + for k in 0..st { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [st - k - 1]; + } + for k in 0..NUM_PER_WORD_BATCH3 as usize - st - 1 { + if k == 0 { + tmp_theta_sum_split = tmp_theta_sum_split * 8 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - 1]; + } else { + tmp_theta_sum_split = tmp_theta_sum_split * 512 + + setup_theta_sum_split_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - k - 1]; + } + } + tmp_theta_sum_split = + tmp_theta_sum_split * 8 + setup_rho_bit_0[rho_index]; + ctx.constr(eq( + setup_rho_bit_0[rho_index] * 64 + setup_rho_bit_1[rho_index], + setup_theta_sum_split_vec[i * PART_SIZE as usize + j][st], + )); + rho_index += 1; + } + + ctx.constr(eq( + tmp_theta_sum_split, + setup_s_vec[i * PART_SIZE as usize + j] + + tmp_theta_sum_split_xor_vec + [(i + PART_SIZE as usize - 1) % PART_SIZE as usize] + .clone() + + tmp_theta_sum_move_split_xor_vec + [(i + 1) % PART_SIZE as usize] + .clone(), + )); + } + } + + let mut tmp_pi_sum_split_xor_vec = setup_theta_sum_split_xor_vec.clone(); + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + tmp_pi_sum_split_xor_vec + [j * PART_SIZE as usize + ((2 * i + 3 * j) % PART_SIZE as usize)] = + setup_theta_sum_split_xor_vec[i * PART_SIZE as usize + j].clone(); + } + } + + // chi + // a[x] = a[x] ^ (~a[x+1] & a[x+2]) + // chi(3 - 2a[x] + a[x+1] - a[x+2]) + ctx.add_lookup(param.constants_table.apply(round).apply(round_cst)); + for i in 0..PART_SIZE as usize { + for j in 0..PART_SIZE as usize { + for k in 0..NUM_PER_WORD_BATCH3 as usize { + ctx.add_lookup( + param + .chi_table + .apply( + tmp_pi_sum_split_xor_vec[((i + 1) + % PART_SIZE as usize) + * PART_SIZE as usize + + j][k] + - tmp_pi_sum_split_xor_vec + [i * PART_SIZE as usize + j][k] + - tmp_pi_sum_split_xor_vec + [i * PART_SIZE as usize + j][k] + - tmp_pi_sum_split_xor_vec[((i + 2) + % PART_SIZE as usize) + * PART_SIZE as usize + + j][k] + + 219, + ) + .apply( + setup_chi_split_value_vec[i * PART_SIZE as usize + j] + [k], + ), + ); + } + + let mut tmp_sum_split_chi_vec = setup_chi_split_value_vec + [i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - 1] + * 1; + for k in 1..NUM_PER_WORD_BATCH3 as usize { + tmp_sum_split_chi_vec = tmp_sum_split_chi_vec * 512 + + setup_chi_split_value_vec[i * PART_SIZE as usize + j] + [NUM_PER_WORD_BATCH3 as usize - k - 1]; + } + + if i != 0 || j != 0 { + if last == 1 { + ctx.transition(eq( + tmp_sum_split_chi_vec, + setup_s_new_vec[i * PART_SIZE as usize + j], + )); + } else { + ctx.transition(eq( + tmp_sum_split_chi_vec, + setup_s_vec[i * PART_SIZE as usize + j].next(), + )); + } + } else { + let mut tmp_sum_s_split_vec = + setup_final_sum_split_vec[NUM_PER_WORD_BATCH4 as usize - 1] * 1; + let mut tmp_sum_s_split_xor_vec = + setup_final_xor_split_vec[NUM_PER_WORD_BATCH4 as usize - 1] * 1; + ctx.add_lookup( + param + .xor_table_batch4 + .apply( + setup_final_sum_split_vec + [NUM_PER_WORD_BATCH4 as usize - 1], + ) + .apply( + setup_final_xor_split_vec + [NUM_PER_WORD_BATCH4 as usize - 1], + ), + ); + for (&value, &xor_value) in setup_final_sum_split_vec + .iter() + .zip(setup_final_xor_split_vec.iter()) + .rev() + .skip(1) + { + tmp_sum_s_split_vec = tmp_sum_s_split_vec * 64 * 64 + value; + tmp_sum_s_split_xor_vec = + tmp_sum_s_split_xor_vec * 64 * 64 + xor_value; + ctx.add_lookup( + param.xor_table_batch4.apply(value).apply(xor_value), + ); + } + + ctx.constr(eq( + tmp_sum_s_split_vec, + tmp_sum_split_chi_vec + round_cst, + )); + if last == 1 { + ctx.transition(eq( + tmp_sum_s_split_xor_vec, + setup_s_new_vec[i * PART_SIZE as usize + j], + )); + } else { + ctx.transition(eq( + tmp_sum_s_split_xor_vec, + setup_s_vec[i * PART_SIZE as usize + j].next(), + )); + } + } + } + } + + if last == 1 { + for i in 0..SQUEEZE_VECTOR_NUM as usize { + let mut tmp_squeeze_split_sum = + setup_squeeze_split_vec[i][SQUEEZE_SPLIT_NUM as usize / 2 - 1] * 1; + for j in 1..SQUEEZE_SPLIT_NUM as usize / 2 { + tmp_squeeze_split_sum = tmp_squeeze_split_sum * 4096 * 4096 + + setup_squeeze_split_vec[i] + [SQUEEZE_SPLIT_NUM as usize / 2 - j - 1]; + } + ctx.constr(eq( + tmp_squeeze_split_sum, + setup_s_new_vec[i * PART_SIZE as usize], + )); + for j in 0..SQUEEZE_SPLIT_NUM as usize / 2 { + ctx.add_lookup( + param + .pack_table + .apply(setup_squeeze_split_vec[i][j]) + .apply(setup_squeeze_split_output_vec[i][j]), + ); + } + // hash_rlc + let mut tmp_hash_rlc_value = setup_squeeze_split_output_vec[0][0] * 1; + + for (i, values) in setup_squeeze_split_output_vec.iter().enumerate() { + for (j, &value) in values + .iter() + .enumerate() + .take(SQUEEZE_SPLIT_NUM as usize / 2) + { + if i != 0 || j != 0 { + tmp_hash_rlc_value = tmp_hash_rlc_value * 256 + value; + } + } + } + } + } + + if last == 1 { + ctx.constr(eq(round + 1, NUM_ROUNDS)); + ctx.constr(eq(input_len, input_acc)); + ctx.constr(eq(padded, 1)); + } else { + ctx.constr(eq((round + 1 - next_round) * next_round, 0)); + // xor((round + 1 = next_round), (round + 1 = NUM_ROUNDS)) + // (round + 1 - next_round) / NUM_ROUNDS = 0, round < 23; 1, round = 23 + // (round + 1 - NUM_ROUNDS) / (NUM_ROUNDS - next_round) = 1,round < 23; 0, + // round = 23 (round + 1 - next_round) / NUM_ROUNDS + // + (round + 1 - NUM_ROUNDS) / (NUM_ROUNDS - next_round) + // - 2 * ((round + 1 - next_round) / NUM_ROUNDS) * ((round + 1 - NUM_ROUNDS) + // / (NUM_ROUNDS - next_round)) = 1 + // (round + 1 - next_round) * (NUM_ROUNDS - next_round) + (round + 1 - + // NUM_ROUNDS) * NUM_ROUNDS + 2 * (round + 1 - + // next_round) * (round + 1 - NUM_ROUNDS) = NUM_ROUNDS * (NUM_ROUNDS - + // next_round) + ctx.constr(eq( + (round + 1 - next_round) * (next_round - NUM_ROUNDS) + + (round + 1 - NUM_ROUNDS) * NUM_ROUNDS + - (round + 1 - next_round) * (round + 1 - NUM_ROUNDS) * 2, + (next_round - NUM_ROUNDS) * NUM_ROUNDS, + )); + ctx.transition(eq(next_round, round.next())); + ctx.transition(eq(input_len, input_len.next())); + ctx.transition(eq(data_rlc, data_rlc.next())); + ctx.transition(eq(padded, padded.next())); + } + }); + + ctx.wg::, _>(move |ctx, values| { + ctx.assign(round, values.round); + ctx.assign(round_cst, values.round_cst); + if last == 0 { + ctx.assign(next_round, values.next_round); + } + for (q, v) in s_vec.iter().zip(values.s_vec.iter()) { + ctx.assign(*q, *v) + } + for (q_vec, v_vec) in theta_split_vec.iter().zip(values.theta_split_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q_vec, v_vec) in theta_split_xor_vec + .iter() + .zip(values.theta_split_xor_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q_vec, v_vec) in theta_sum_split_vec + .iter() + .zip(values.theta_sum_split_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q_vec, v_vec) in theta_sum_split_xor_vec + .iter() + .zip(values.theta_sum_split_xor_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q, v) in rho_bit_0.iter().zip(values.rho_bit_0.iter()) { + ctx.assign(*q, *v) + } + for (q, v) in rho_bit_1.iter().zip(values.rho_bit_1.iter()) { + ctx.assign(*q, *v) + } + for (q_vec, v_vec) in chi_split_value_vec + .iter() + .zip(values.chi_split_value_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q, v) in final_sum_split_vec + .iter() + .zip(values.final_sum_split_vec.iter()) + { + ctx.assign(*q, *v) + } + for (q, v) in final_xor_split_vec + .iter() + .zip(values.final_xor_split_vec.iter()) + { + ctx.assign(*q, *v) + } + if last == 1 { + for (q, v) in s_new_vec.iter().zip(values.svalues.s_new_vec.iter()) { + ctx.assign(*q, *v) + } + for (q_vec, v_vec) in squeeze_split_vec + .iter() + .zip(values.svalues.squeeze_split_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + for (q_vec, v_vec) in squeeze_split_output_vec + .iter() + .zip(values.svalues.squeeze_split_output_vec.iter()) + { + for (q, v) in q_vec.iter().zip(v_vec.iter()) { + ctx.assign(*q, *v) + } + } + ctx.assign(hash_rlc, values.svalues.hash_rlc); + } + ctx.assign(input_len, values.input_len); + ctx.assign(data_rlc, values.data_rlc); + ctx.assign(input_acc, values.input_acc); + ctx.assign(padded, values.padded); + }) + }) + }) + .collect(); + + ctx.pragma_first_step(&keccak_first_step); // keccak_pre_step + ctx.pragma_last_step(&keccak_one_round_step_vec[1]); // keccak_squeeze_step + ctx.pragma_num_steps(param.step_num); + + ctx.trace(move |ctx, params| { + let input_num = params.bytes.len(); + let mut bits = convert_bytes_to_bits(params.bytes); + println!("intput bits(without padding) = {:?}", bits); + // padding + bits.push(1); + while (bits.len() + 1) % RATE_IN_BITS as usize != 0 { + bits.push(0); + } + bits.push(1); + println!("intput bits(with padding) = {:?}", bits); + + let mut s_new = [F::ZERO; PART_SIZE_SQURE as usize]; + + // chunks + let chunks = bits.chunks(RATE_IN_BITS as usize); + let chunks_len = chunks.len(); + let mut data_rlc_value = F::ZERO; + let mut input_acc = F::ZERO; + // absorb + for (k, chunk) in chunks.enumerate() { + let s: Vec = s_new.to_vec(); + let absorbs: Vec = (0..PART_SIZE_SQURE as usize) + .map(|idx| { + let i = idx % PART_SIZE as usize; + let j = idx / PART_SIZE as usize; + let mut absorb = F::ZERO; + if idx < NUM_WORDS_TO_ABSORB as usize { + absorb = pack(&chunk[idx * 64..(idx + 1) * 64]); + s_new[i * PART_SIZE as usize + j] = + field_xor(s[i * PART_SIZE as usize + j], absorb); + } else { + s_new[i * PART_SIZE as usize + j] = s[i * PART_SIZE as usize + j]; + } + absorb + }) + .collect(); + + let absorb_split_vec: Vec> = (0..NUM_WORDS_TO_ABSORB as usize) + .map(|idx| { + let bits = chunk[idx * 64..(idx + 1) * 64].to_vec(); + (0..SQUEEZE_SPLIT_NUM as usize / 2) + .map(|k| { + F::from( + bits[k * 8] as u64 + + bits[k * 8 + 1] as u64 * 8 + + bits[k * 8 + 2] as u64 * 64 + + bits[k * 8 + 3] as u64 * 512 + + bits[k * 8 + 4] as u64 * 4096 + + bits[k * 8 + 5] as u64 * 8 * 4096 + + bits[k * 8 + 6] as u64 * 64 * 4096 + + bits[k * 8 + 7] as u64 * 512 * 4096, + ) + }) + .collect() + }) + .collect(); + + let absorb_split_input_vec: Vec> = (0..NUM_WORDS_TO_ABSORB as usize) + .map(|idx| { + let bits = chunk[idx * 64..(idx + 1) * 64].to_vec(); + (0..SQUEEZE_SPLIT_NUM as usize / 2) + .map(|k| { + F::from( + bits[k * 8] as u64 + + bits[k * 8 + 1] as u64 * 2 + + bits[k * 8 + 2] as u64 * 4 + + bits[k * 8 + 3] as u64 * 8 + + bits[k * 8 + 4] as u64 * 16 + + bits[k * 8 + 5] as u64 * 32 + + bits[k * 8 + 6] as u64 * 64 + + bits[k * 8 + 7] as u64 * 128, + ) + }) + .collect() + }) + .collect(); + + let mut is_padding_vec = vec![vec![F::ONE; 8]; NUM_WORDS_TO_ABSORB as usize]; + is_padding_vec = is_padding_vec + .iter() + .enumerate() + .map(|(i, is_paddings)| { + is_paddings + .iter() + .enumerate() + .take(8) + .map(|(j, &is_padding)| { + if input_num > k * 8 * NUM_WORDS_TO_ABSORB as usize + i * 8 + j { + F::ZERO + } else { + is_padding + } + }) + .collect() + }) + .collect(); + + let mut padded = F::ZERO; + if k == 0 { + let data_rlc = data_rlc_value; + let data_rlc_vec: Vec> = absorb_split_input_vec + .iter() + .zip(is_padding_vec.iter()) + .map(|(v1_vec, v2_vec)| { + v1_vec + .iter() + .zip(v2_vec.iter()) + .map(|(&v1, &v2)| { + if v2 == F::ZERO { + data_rlc_value = data_rlc_value * F::from(256) + v1 + } + data_rlc_value + }) + .collect() + }) + .collect(); + + let values = PreValues { + s_vec: s, + absorb_rows: absorbs[0..NUM_WORDS_TO_ABSORB as usize].to_vec(), + round_value: F::ZERO, + absorb_split_vec, + absorb_split_input_vec, + split_values: Vec::new(), + is_padding_vec: is_padding_vec.clone(), + input_len: F::from(input_num as u64), + data_rlc_vec, + data_rlc, + input_acc, + padded, + }; + ctx.add(&keccak_first_step, values); + } else { + let data_rlc = data_rlc_value; + let split_values = (0..NUM_WORDS_TO_ABSORB as usize) + .map(|t| { + let i = t % PART_SIZE as usize; + let j = t / PART_SIZE as usize; + let v = i * PART_SIZE as usize + j; + eval_keccak_f_to_bit_vec4::( + s[v] + absorbs[(v % PART_SIZE as usize) * PART_SIZE as usize + + (v / PART_SIZE as usize)], + s_new[v], + ) + }) + .collect(); + + let data_rlc_vec: Vec> = absorb_split_input_vec + .iter() + .zip(is_padding_vec.iter()) + .map(|(v1_vec, v2_vec)| { + v1_vec + .iter() + .zip(v2_vec.iter()) + .map(|(&v1, &v2)| { + if v2 == F::ZERO { + data_rlc_value = data_rlc_value * F::from(256) + v1 + } + data_rlc_value + }) + .collect() + }) + .collect(); + let values = PreValues { + s_vec: s, + absorb_rows: absorbs[0..NUM_WORDS_TO_ABSORB as usize].to_vec(), + split_values, + absorb_split_vec, + absorb_split_input_vec, + round_value: F::ZERO, + is_padding_vec: is_padding_vec.clone(), + input_len: F::from(input_num as u64), + data_rlc_vec, + data_rlc, + input_acc, + padded, + }; + ctx.add(&keccak_pre_step, values); + } + padded = is_padding_vec[NUM_WORDS_TO_ABSORB as usize - 1][7]; + + input_acc = is_padding_vec.iter().fold(input_acc, |acc, is_paddings| { + let v = is_paddings + .iter() + .fold(F::ZERO, |acc, is_padding| acc + (F::ONE - is_padding)); + acc + v + }); + + for (round, &cst) in ROUND_CST.iter().enumerate().take(NUM_ROUNDS as usize) { + let mut values = eval_keccak_f_one_round( + round as u64, + cst, + s_new.to_vec(), + input_num as u64, + data_rlc_value, + input_acc, + padded, + ); + s_new = values.s_new_vec.clone().try_into().unwrap(); + + if k != chunks_len - 1 || round != NUM_ROUNDS as usize - 1 { + ctx.add(&keccak_one_round_step_vec[0], values.clone()); + } else { + // squeezing + let mut squeeze_split_vec: Vec> = Vec::new(); + let mut squeeze_split_output_vec: Vec> = Vec::new(); + for i in 0..4 { + let bits = convert_field_to_vec_bits(s_new[(i * PART_SIZE) as usize]); + + squeeze_split_vec.push( + (0..SQUEEZE_SPLIT_NUM as usize / 2) + .map(|k| { + let value = bits[k * 8] as u64 + + bits[k * 8 + 1] as u64 * 8 + + bits[k * 8 + 2] as u64 * 64 + + bits[k * 8 + 3] as u64 * 512 + + bits[k * 8 + 4] as u64 * 4096 + + bits[k * 8 + 5] as u64 * 8 * 4096 + + bits[k * 8 + 6] as u64 * 64 * 4096 + + bits[k * 8 + 7] as u64 * 512 * 4096; + F::from(value) + }) + .collect(), + ); + squeeze_split_output_vec.push( + (0..SQUEEZE_SPLIT_NUM as usize / 2) + .map(|k| { + let value = bits[k * 8] as u64 + + bits[k * 8 + 1] as u64 * 2 + + bits[k * 8 + 2] as u64 * 4 + + bits[k * 8 + 3] as u64 * 8 + + bits[k * 8 + 4] as u64 * 16 + + bits[k * 8 + 5] as u64 * 32 + + bits[k * 8 + 6] as u64 * 64 + + bits[k * 8 + 7] as u64 * 128; + F::from(value) + }) + .collect(), + ); + } + let mut hash_rlc = F::ZERO; + for squeeze_split_output in squeeze_split_output_vec.iter().take(4) { + for output in squeeze_split_output + .iter() + .take(SQUEEZE_SPLIT_NUM as usize / 2) + { + hash_rlc = hash_rlc * F::from(256) + output; + } + } + values.svalues = SqueezeValues { + s_new_vec: s_new.to_vec(), + squeeze_split_vec, + squeeze_split_output_vec, + hash_rlc, + }; + ctx.add(&keccak_one_round_step_vec[1], values); + } + } + } + + let output2: Vec> = (0..4) + .map(|k| { + pack_with_base::( + &convert_field_to_vec_bits(s_new[(k * PART_SIZE) as usize]), + 2, + ) + .to_repr() + .into_iter() + .take(8) + .collect::>() + .to_vec() + }) + .collect(); + println!("output = {:x?}", output2.concat()); + }); +} + +#[derive(Default)] +struct KeccakCircuit { + // pub bits: Vec, + pub bytes: Vec, +} + +struct CircuitParams { + pub constants_table: LookupTable, + pub xor_table: LookupTable, + pub xor_table_batch3: LookupTable, + pub xor_table_batch4: LookupTable, + pub chi_table: LookupTable, + pub pack_table: LookupTable, + pub step_num: usize, +} + +fn keccak_super_circuit + Eq + Hash>( + input_len: usize, +) -> SuperCircuit { + super_circuit::("keccak", |ctx| { + let in_n = (input_len * 8 + 1 + RATE_IN_BITS as usize) / RATE_IN_BITS as usize; + let step_num = in_n * (1 + NUM_ROUNDS as usize); + + let single_config = config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}); + // config(SingleRowCellManager {}, LogNSelectorBuilder {}); + + let (_, constants_table) = ctx.sub_circuit( + single_config.clone(), + keccak_round_constants_table, + NUM_ROUNDS as usize + 1, + ); + let (_, xor_table) = ctx.sub_circuit(single_config.clone(), keccak_xor_table_batch2, 36); + let (_, xor_table_batch3) = + ctx.sub_circuit(single_config.clone(), keccak_xor_table_batch3, 64); + let (_, xor_table_batch4) = + ctx.sub_circuit(single_config.clone(), keccak_xor_table_batch4, 81); + let (_, chi_table) = ctx.sub_circuit(single_config.clone(), keccak_chi_table, 125); + let (_, pack_table) = ctx.sub_circuit(single_config, keccak_pack_table, 0); + + let params = CircuitParams { + constants_table, + xor_table, + xor_table_batch3, + xor_table_batch4, + chi_table, + pack_table, + step_num, + }; + + let maxwidth_config = config( + MaxWidthCellManager::new(198, true), + SimpleStepSelectorBuilder {}, + ); + let (keccak, _) = ctx.sub_circuit(maxwidth_config, keccak_circuit, params); + + ctx.mapping(move |ctx, values| { + ctx.map(&keccak, values); + }) + }) +} + +use chiquito::plonkish::backend::plaf::chiquito2Plaf; +use polyexen::plaf::{Plaf, PlafDisplayBaseTOML, PlafDisplayFixedCSV, Witness, WitnessDisplayCSV}; + +fn write_files(name: &str, plaf: &Plaf, wit: &Witness) -> Result<(), io::Error> { + let mut base_file = File::create(format!("{}.toml", name))?; + let mut fixed_file = File::create(format!("{}_fixed.csv", name))?; + let mut witness_file = File::create(format!("{}_witness.csv", name))?; + + write!(base_file, "{}", PlafDisplayBaseTOML(plaf))?; + write!(fixed_file, "{}", PlafDisplayFixedCSV(plaf))?; + write!(witness_file, "{}", WitnessDisplayCSV(wit))?; + println!("write file success...{}", name); + Ok(()) +} + +fn keccak_plaf(circuit_param: KeccakCircuit, k: u32) { + let super_circuit = keccak_super_circuit::(circuit_param.bytes.len()); + let witness = super_circuit.get_mapping().generate(circuit_param); + + for wit_gen in witness.values() { + let wit_gen = wit_gen.clone(); + + let mut circuit = super_circuit.get_sub_circuits()[0].clone(); + circuit + .columns + .append(&mut super_circuit.get_sub_circuits()[1].columns); + circuit + .columns + .append(&mut super_circuit.get_sub_circuits()[2].columns); + circuit + .columns + .append(&mut super_circuit.get_sub_circuits()[3].columns); + circuit + .columns + .append(&mut super_circuit.get_sub_circuits()[4].columns); + circuit + .columns + .append(&mut super_circuit.get_sub_circuits()[5].columns); + circuit + .columns + .append(&mut super_circuit.get_sub_circuits()[6].columns); + + for (key, value) in super_circuit.get_sub_circuits()[0].fixed_assignments.iter() { + circuit.fixed_assignments.insert(key.clone(), value.clone()); + } + for (key, value) in super_circuit.get_sub_circuits()[1].fixed_assignments.iter() { + circuit.fixed_assignments.insert(key.clone(), value.clone()); + } + for (key, value) in super_circuit.get_sub_circuits()[2].fixed_assignments.iter() { + circuit.fixed_assignments.insert(key.clone(), value.clone()); + } + for (key, value) in super_circuit.get_sub_circuits()[3].fixed_assignments.iter() { + circuit.fixed_assignments.insert(key.clone(), value.clone()); + } + for (key, value) in super_circuit.get_sub_circuits()[4].fixed_assignments.iter() { + circuit.fixed_assignments.insert(key.clone(), value.clone()); + } + for (key, value) in super_circuit.get_sub_circuits()[5].fixed_assignments.iter() { + circuit.fixed_assignments.insert(key.clone(), value.clone()); + } + + let (plaf, plaf_wit_gen) = chiquito2Plaf(circuit, k, false); + + let mut plaf = plaf; + plaf.set_challenge_alias(0, "r_keccak".to_string()); + let wit = plaf_wit_gen.generate(Some(wit_gen)); + write_files("keccak_output", &plaf, &wit).unwrap(); + } +} + +fn keccak_run(circuit_param: KeccakCircuit, k: u32) -> bool { + let super_circuit = keccak_super_circuit::(circuit_param.bytes.len()); + + let compiled = chiquitoSuperCircuit2Halo2(&super_circuit); + + let circuit = ChiquitoHalo2SuperCircuit::new( + compiled, + super_circuit.get_mapping().generate(circuit_param), + ); + + let prover = MockProver::::run(k, &circuit, Vec::new()).unwrap(); + let result = prover.verify(); + + println!("result = {:#?}", result); + + if let Err(failures) = &result { + for failure in failures.iter() { + println!("{}", failure); + } + false + } else { + true + } +} + +fn main() { + let circuit_param = KeccakCircuit { + bytes: vec![0, 1, 2, 3, 4, 5, 6, 7], + }; + + let res = keccak_run(circuit_param, 9); + + if res { + keccak_plaf( + KeccakCircuit { + bytes: vec![0, 1, 2, 3, 4, 5, 6, 7], + }, + 11, + ); + } +} diff --git a/examples/mimc7.rs b/examples/mimc7.rs index fa1b1d2e..595e02b1 100644 --- a/examples/mimc7.rs +++ b/examples/mimc7.rs @@ -175,7 +175,7 @@ fn mimc7_circuit( row_value += F::from(1); x_value += k_value + c_value; x_value = x_value.pow_vartime([7_u64]); - // Step 90: output the hash result as x + k in witness generation + // Step 91: output the hash result as x + k in witness generation // output is not displayed as a public column, which will be implemented in the future ctx.add(&mimc7_last_step, (x_value, k_value, c_value, row_value)); // c_value is not // used here but @@ -210,7 +210,7 @@ fn main() { let prover = MockProver::::run(10, &circuit, circuit.instance()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("result = {:#?}", result); @@ -219,6 +219,29 @@ fn main() { println!("{}", failure); } } + + // pil boilerplate + use chiquito::pil::backend::powdr_pil::chiquitoSuperCircuit2Pil; + + let x_in_value = Fr::from_str_vartime("1").expect("expected a number"); + let k_value = Fr::from_str_vartime("2").expect("expected a number"); + + let super_circuit = mimc7_super_circuit::(); + + // `super_trace_witnesses` is a mapping from IR id to TraceWitness. However, not all ASTs have a + // corresponding TraceWitness. + let super_trace_witnesses = super_circuit + .get_mapping() + .generate_super_trace_witnesses((x_in_value, k_value)); + + let pil = chiquitoSuperCircuit2Pil::( + super_circuit.get_super_asts(), + super_trace_witnesses, + super_circuit.get_ast_id_to_ir_id_mapping(), + vec![String::from("Mimc7Constant"), String::from("Mimc7Circuit")], + ); + + print!("{}", pil); } mod mimc7_constants { diff --git a/examples/poseidon.rs b/examples/poseidon.rs index 0b626235..aeb476e3 100644 --- a/examples/poseidon.rs +++ b/examples/poseidon.rs @@ -11,7 +11,7 @@ use chiquito::{ }, sbpir::query::Queriable, }; -// use halo2curves::ff::Field; + use std::hash::Hash; use halo2_proofs::{ @@ -484,7 +484,6 @@ fn poseidon_circuit( x_value }) .collect(); - let mut sbox_values: Vec = x_values .iter() .map(|x_value| *x_value * x_value * x_value * x_value * x_value) @@ -513,7 +512,6 @@ fn poseidon_circuit( out_values: outputs.clone(), round: F::ZERO, }; - ctx.add(&poseidon_step_first_round, round_values); inputs = outputs; @@ -712,7 +710,7 @@ fn main() { let prover = MockProver::::run(12, &circuit, Vec::new()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("result = {:#?}", result); diff --git a/rust-toolchain b/rust-toolchain index b6ce6a50..c5f61037 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly-2023-04-24 +nightly-2024-02-14 diff --git a/src/frontend/dsl/cb.rs b/src/frontend/dsl/cb.rs index 5b2f577f..d41b041b 100644 --- a/src/frontend/dsl/cb.rs +++ b/src/frontend/dsl/cb.rs @@ -716,4 +716,51 @@ mod tests { matches!(v[1], Expr::Const(c) if c == 40u64.field())) && matches!(v[1], Expr::Const(c) if c == 10u64.field()))); } + + #[test] + fn test_constraint_from_queriable() { + // Create a Queriable instance and convert it to a Constraint + let queriable = Queriable::StepTypeNext(StepTypeHandler::new("test_step".to_owned())); + let constraint: Constraint = Constraint::from(queriable); + + assert_eq!(constraint.annotation, "test_step"); + assert!( + matches!(constraint.expr, Expr::Query(Queriable::StepTypeNext(s)) if + matches!(s, StepTypeHandler {id: _id, annotation: "test_step"})) + ); + assert!(matches!(constraint.typing, Typing::Boolean)); + } + + #[test] + fn test_constraint_from_expr() { + // Create an expression and convert it to a Constraint + let expr = >>::expr(&10) * 20u64.expr(); + let constraint: Constraint = Constraint::from(expr); + + // returns "10 * 20" + assert!(matches!(constraint.expr, Expr::Mul(v) if v.len() == 2 && + matches!(v[0], Expr::Const(c) if c == 10u64.field()) && + matches!(v[1], Expr::Const(c) if c == 20u64.field()))); + assert!(matches!(constraint.typing, Typing::Unknown)); + } + + #[test] + fn test_constraint_from_int() { + // Create an integer and convert it to a Constraint + let constraint: Constraint = Constraint::from(10); + + // returns "10" + assert!(matches!(constraint.expr, Expr::Const(c) if c == 10u64.field())); + assert!(matches!(constraint.typing, Typing::Unknown)); + } + + #[test] + fn test_constraint_from_bool() { + // Create a boolean and convert it to a Constraint + let constraint: Constraint = Constraint::from(true); + + assert_eq!(constraint.annotation, "0x1"); + assert!(matches!(constraint.expr, Expr::Const(c) if c == 1u64.field())); + assert!(matches!(constraint.typing, Typing::Unknown)); + } } diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index c90db77e..e52313b5 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -154,6 +154,8 @@ impl CircuitContext { self.circuit.last_step = Some(step_type.into().uuid()); } + /// Enforce the number of step instances by adding a constraint to the circuit. Takes a `usize` + /// parameter that represents the total number of steps. pub fn pragma_num_steps(&mut self, num_steps: usize) { self.circuit.num_steps = num_steps; } @@ -201,6 +203,12 @@ impl From<&'static str> for StepTypeDefInput { } } +impl From for StepTypeDefInput { + fn from(s: String) -> Self { + StepTypeDefInput::String(Box::leak(s.into_boxed_str())) + } +} + /// A generic structure designed to handle the context of a step type definition. The struct /// contains a `StepType` instance and implements methods to build the step type, add components, /// and manipulate the step type. `F` is a generic type representing the field of the step type. @@ -225,6 +233,7 @@ impl StepTypeContext { } /// DEPRECATED + // #[deprecated(note = "use step types setup for constraints instead")] pub fn constr>>(&mut self, constraint: C) { println!("DEPRECATED constr: use setup for constraints in step types"); @@ -235,6 +244,7 @@ impl StepTypeContext { } /// DEPRECATED + #[deprecated(note = "use step types setup for constraints instead")] pub fn transition>>(&mut self, constraint: C) { println!("DEPRECATED transition: use setup for constraints in step types"); @@ -358,6 +368,10 @@ impl StepTypeHandler { pub fn next(&self) -> Queriable { Queriable::StepTypeNext(*self) } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } impl, Args) + 'static> From<&StepTypeWGHandler> @@ -420,28 +434,49 @@ pub mod sc; #[cfg(test)] mod tests { + use crate::sbpir::ForwardSignal; + use super::*; + fn setup_circuit_context() -> CircuitContext + where + F: Default, + TraceArgs: Default, + { + CircuitContext { + circuit: SBPIR::default(), + tables: Default::default(), + } + } + #[test] - fn test_disable_q_enable() { + fn test_circuit_default_initialization() { let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; - context.pragma_disable_q_enable(); + // Assert default values + assert!(circuit.step_types.is_empty()); + assert!(circuit.forward_signals.is_empty()); + assert!(circuit.shared_signals.is_empty()); + assert!(circuit.fixed_signals.is_empty()); + assert!(circuit.exposed.is_empty()); + assert!(circuit.annotations.is_empty()); + assert!(circuit.trace.is_none()); + assert!(circuit.first_step.is_none()); + assert!(circuit.last_step.is_none()); + assert!(circuit.num_steps == 0); + assert!(circuit.q_enable); + } + #[test] + fn test_disable_q_enable() { + let mut context = setup_circuit_context::(); + context.pragma_disable_q_enable(); assert!(!context.circuit.q_enable); } #[test] fn test_set_num_steps() { - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); context.pragma_num_steps(3); assert_eq!(context.circuit.num_steps, 3); @@ -450,14 +485,29 @@ mod tests { assert_eq!(context.circuit.num_steps, 0); } + #[test] + fn test_set_first_step() { + let mut context = setup_circuit_context::(); + + let step_type: StepTypeHandler = context.step_type("step_type"); + + context.pragma_first_step(step_type); + assert_eq!(context.circuit.first_step, Some(step_type.uuid())); + } + + #[test] + fn test_set_last_step() { + let mut context = setup_circuit_context::(); + + let step_type: StepTypeHandler = context.step_type("step_type"); + + context.pragma_last_step(step_type); + assert_eq!(context.circuit.last_step, Some(step_type.uuid())); + } + #[test] fn test_forward() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set forward signals let forward_a: Queriable = context.forward("forward_a"); @@ -469,14 +519,21 @@ mod tests { assert_eq!(context.circuit.forward_signals[1].uuid(), forward_b.uuid()); } + #[test] + fn test_adding_duplicate_signal_names() { + let mut context = setup_circuit_context::(); + context.forward("duplicate_name"); + context.forward("duplicate_name"); + // Assert how the system should behave. Does it override the previous signal, throw an + // error, or something else? + // TODO: Should we let the user know that they are adding a duplicate signal name? And let + // the circuit have two signals with the same name? + assert_eq!(context.circuit.forward_signals.len(), 2); + } + #[test] fn test_forward_with_phase() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set forward signals with specified phase context.forward_with_phase("forward_a", 1); @@ -490,12 +547,7 @@ mod tests { #[test] fn test_shared() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set shared signal let shared_a: Queriable = context.shared("shared_a"); @@ -507,12 +559,7 @@ mod tests { #[test] fn test_shared_with_phase() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set shared signal with specified phase context.shared_with_phase("shared_a", 2); @@ -524,12 +571,7 @@ mod tests { #[test] fn test_fixed() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set fixed signal context.fixed("fixed_a"); @@ -540,12 +582,7 @@ mod tests { #[test] fn test_expose() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set forward signal and step to expose let forward_a: Queriable = context.forward("forward_a"); @@ -562,14 +599,18 @@ mod tests { ); } + #[test] + #[should_panic(expected = "Signal not found")] + fn test_expose_non_existing_signal() { + let mut context = setup_circuit_context::(); + let non_existing_signal = + Queriable::Forward(ForwardSignal::new_with_phase(0, "".to_owned()), false); // Create a signal not added to the circuit + context.expose(non_existing_signal, ExposeOffset::First); + } + #[test] fn test_step_type() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // create a step type let handler: StepTypeHandler = context.step_type("fibo_first_step"); @@ -583,12 +624,7 @@ mod tests { #[test] fn test_step_type_def() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // create a step type including its definition let simple_step = context.step_type_def("simple_step", |context| { @@ -609,12 +645,7 @@ mod tests { #[test] fn test_step_type_def_pass_handler() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // create a step type handler let handler: StepTypeHandler = context.step_type("simple_step"); @@ -635,4 +666,23 @@ mod tests { context.circuit.step_types[&simple_step.uuid()].uuid() ); } + + #[test] + fn test_trace() { + let mut context = setup_circuit_context::(); + + // set trace function + context.trace(|_, _: i32| {}); + + // assert trace function was set + assert!(context.circuit.trace.is_some()); + } + + #[test] + #[should_panic(expected = "circuit cannot have more than one trace generator")] + fn test_setting_trace_multiple_times() { + let mut context = setup_circuit_context::(); + context.trace(|_, _| {}); + context.trace(|_, _| {}); + } } diff --git a/src/frontend/dsl/sc.rs b/src/frontend/dsl/sc.rs index 9b6daec9..3997ca11 100644 --- a/src/frontend/dsl/sc.rs +++ b/src/frontend/dsl/sc.rs @@ -18,6 +18,7 @@ use crate::{ use super::{lb::LookupTableRegistry, CircuitContext}; +#[derive(Debug)] pub struct SuperCircuitContext { super_circuit: SuperCircuit, sub_circuit_phase1: Vec>, @@ -34,6 +35,12 @@ impl Default for SuperCircuitContext { } } +impl SuperCircuitContext { + fn add_sub_circuit_ast(&mut self, ast: SBPIR) { + self.super_circuit.add_sub_circuit_ast(ast); + } +} + impl SuperCircuitContext { pub fn sub_circuit( &mut self, @@ -48,12 +55,13 @@ impl SuperCircuitContext { circuit: SBPIR::default(), tables: self.tables.clone(), }; - println!("super circuit table registry 2: {:?}", self.tables); let exports = sub_circuit_def(&mut sub_circuit_context, imports); - println!("super circuit table registry 3: {:?}", self.tables); let sub_circuit = sub_circuit_context.circuit; + // ast is used for PIL backend + self.add_sub_circuit_ast(sub_circuit.clone_without_trace()); + let (unit, assignment) = compile_phase1(config, &sub_circuit); let assignment = assignment.unwrap_or_else(|| AssignmentGenerator::empty(unit.uuid)); @@ -113,3 +121,205 @@ where ctx.compile() } + +#[cfg(test)] +mod tests { + use halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}; + + use crate::{ + plonkish::compiler::{ + cell_manager::SingleRowCellManager, config, step_selector::SimpleStepSelectorBuilder, + }, + poly::ToField, + }; + + use super::*; + + #[test] + fn test_super_circuit_context_default() { + let ctx = SuperCircuitContext::::default(); + + assert_eq!( + format!("{:#?}", ctx.super_circuit), + format!("{:#?}", SuperCircuit::::default()) + ); + assert_eq!( + format!("{:#?}", ctx.sub_circuit_phase1), + format!("{:#?}", Vec::>::default()) + ); + assert_eq!(ctx.sub_circuit_phase1.len(), 0); + assert_eq!( + format!("{:#?}", ctx.tables), + format!("{:#?}", LookupTableRegistry::::default()) + ); + } + + #[test] + fn test_super_circuit_context_sub_circuit() { + let mut ctx = SuperCircuitContext::::default(); + + fn simple_circuit(ctx: &mut CircuitContext, _: ()) { + use crate::frontend::dsl::cb::*; + + let x = ctx.forward("x"); + let y = ctx.forward("y"); + + let step_type = ctx.step_type_def("sum should be 10", |ctx| { + ctx.setup(move |ctx| { + ctx.constr(eq(x + y, 10)); + }); + + ctx.wg(move |ctx, (x_value, y_value): (u32, u32)| { + ctx.assign(x, x_value.field()); + ctx.assign(y, y_value.field()); + }) + }); + + ctx.pragma_num_steps(1); + + ctx.trace(move |ctx, ()| { + ctx.add(&step_type, (2, 8)); + }) + } + + // simple circuit to check if the sum of two inputs are 10 + ctx.sub_circuit( + config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}), + simple_circuit, + (), + ); + + // ensure phase 1 was done correctly for the sub circuit + assert_eq!(ctx.sub_circuit_phase1.len(), 1); + assert_eq!(ctx.sub_circuit_phase1[0].columns.len(), 4); + assert_eq!( + ctx.sub_circuit_phase1[0].columns[0].annotation, + "srcm forward x" + ); + assert_eq!( + ctx.sub_circuit_phase1[0].columns[1].annotation, + "srcm forward y" + ); + assert_eq!(ctx.sub_circuit_phase1[0].columns[2].annotation, "q_enable"); + assert_eq!( + ctx.sub_circuit_phase1[0].columns[3].annotation, + "'step selector for sum should be 10'" + ); + assert_eq!(ctx.sub_circuit_phase1[0].forward_signals.len(), 2); + assert_eq!(ctx.sub_circuit_phase1[0].step_types.len(), 1); + assert_eq!(ctx.sub_circuit_phase1[0].compilation_phase, 1); + } + + #[test] + fn test_super_circuit_compile() { + let mut ctx = SuperCircuitContext::::default(); + + fn simple_circuit(ctx: &mut CircuitContext, _: ()) { + use crate::frontend::dsl::cb::*; + + let x = ctx.forward("x"); + let y = ctx.forward("y"); + + let step_type = ctx.step_type_def("sum should be 10", |ctx| { + ctx.setup(move |ctx| { + ctx.constr(eq(x + y, 10)); + }); + + ctx.wg(move |ctx, (x_value, y_value): (u32, u32)| { + ctx.assign(x, x_value.field()); + ctx.assign(y, y_value.field()); + }) + }); + + ctx.pragma_num_steps(1); + + ctx.trace(move |ctx, ()| { + ctx.add(&step_type, (2, 8)); + }) + } + + // simple circuit to check if the sum of two inputs are 10 + ctx.sub_circuit( + config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}), + simple_circuit, + (), + ); + + let super_circuit = ctx.compile(); + + assert_eq!(super_circuit.get_sub_circuits().len(), 1); + assert_eq!(super_circuit.get_sub_circuits()[0].columns.len(), 4); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[0].annotation, + "srcm forward x" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[1].annotation, + "srcm forward y" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[2].annotation, + "q_enable" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[3].annotation, + "'step selector for sum should be 10'" + ); + } + + #[test] + fn test_super_circuit_sub_circuit_with_ast() { + use crate::frontend::dsl::circuit; + let mut ctx = SuperCircuitContext::::default(); + + let simple_circuit_with_ast = circuit("simple circuit", |ctx| { + use crate::frontend::dsl::cb::*; + + let x = ctx.forward("x"); + let y = ctx.forward("y"); + + let step_type = ctx.step_type_def("sum should be 10", |ctx| { + ctx.setup(move |ctx| { + ctx.constr(eq(x + y, 10)); + }); + + ctx.wg(move |ctx, (x_value, y_value): (u32, u32)| { + ctx.assign(x, x_value.field()); + ctx.assign(y, y_value.field()); + }) + }); + + ctx.pragma_num_steps(1); + + ctx.trace(move |ctx, ()| { + ctx.add(&step_type, (2, 8)); + }); + }); + + ctx.sub_circuit_with_ast( + config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}), + simple_circuit_with_ast, + ); + + let super_circuit = ctx.compile(); + + assert_eq!(super_circuit.get_sub_circuits().len(), 1); + assert_eq!(super_circuit.get_sub_circuits()[0].columns.len(), 4); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[0].annotation, + "srcm forward x" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[1].annotation, + "srcm forward y" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[2].annotation, + "q_enable" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[3].annotation, + "'step selector for sum should be 10'" + ); + } +} diff --git a/src/frontend/python/chiquito/cb.py b/src/frontend/python/chiquito/cb.py index 6d25d920..5828fe09 100644 --- a/src/frontend/python/chiquito/cb.py +++ b/src/frontend/python/chiquito/cb.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field from enum import Enum, auto -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union from chiquito.util import F, uuid from chiquito.expr import Expr, Const, Neg, to_expr, ToExpr @@ -205,7 +205,7 @@ def table() -> LookupTable: return LookupTable() -ToConstraint = Constraint | Expr | int | F +ToConstraint = Union[Constraint, Expr, int, F] def to_constraint(v: ToConstraint) -> Constraint: diff --git a/src/frontend/python/chiquito/chiquito_ast.py b/src/frontend/python/chiquito/chiquito_ast.py index fc6e6d0b..3b44505a 100644 --- a/src/frontend/python/chiquito/chiquito_ast.py +++ b/src/frontend/python/chiquito/chiquito_ast.py @@ -126,7 +126,7 @@ def __json__(self: ASTCircuit): "last_step": self.last_step, "num_steps": self.num_steps, "q_enable": self.q_enable, - "id": self.id, + "id": self.id.__str__(), } def add_forward(self: ASTCircuit, name: str, phase: int) -> ForwardSignal: diff --git a/src/frontend/python/chiquito/dsl.py b/src/frontend/python/chiquito/dsl.py index aa355972..51f089c7 100644 --- a/src/frontend/python/chiquito/dsl.py +++ b/src/frontend/python/chiquito/dsl.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import List, Dict +from typing import List, Dict, Union from enum import Enum from typing import Callable, Any @@ -225,6 +225,15 @@ def halo2_mock_prover(self: Circuit, witness: TraceWitness, k: int = 16): witness_json: str = witness.get_witness_json() rust_chiquito.halo2_mock_prover(witness_json, self.rust_id, k) + def to_pil( + self: Circuit, witness: TraceWitness, circuit_name: str = "Circuit" + ) -> str: + if self.rust_id == 0: + ast_json: str = self.get_ast_json() + self.rust_id: int = rust_chiquito.ast_to_halo2(ast_json) + witness_json: str = witness.get_witness_json() + rust_chiquito.to_pil(witness_json, self.rust_id, circuit_name) + def __str__(self: Circuit) -> str: return self.ast.__str__() @@ -286,4 +295,4 @@ def add_lookup(self: StepType, lookup_builder: LookupBuilder): self.step_type.lookups.append(lookup) -LookupBuilder = LookupTableBuilder | InPlaceLookupBuilder +LookupBuilder = Union[LookupTableBuilder, InPlaceLookupBuilder] diff --git a/src/frontend/python/chiquito/expr.py b/src/frontend/python/chiquito/expr.py index 22e11396..41f4d93f 100644 --- a/src/frontend/python/chiquito/expr.py +++ b/src/frontend/python/chiquito/expr.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import List +from typing import List, Union from dataclasses import dataclass from chiquito.util import F @@ -141,7 +141,7 @@ def __json__(self): return {"Pow": [self.expr.__json__(), self.pow]} -ToExpr = Expr | int | F +ToExpr = Union[Expr, int, F] def to_expr(v: ToExpr) -> Expr: diff --git a/src/frontend/python/chiquito/query.py b/src/frontend/python/chiquito/query.py index 9dafb6c2..0eeb52e7 100644 --- a/src/frontend/python/chiquito/query.py +++ b/src/frontend/python/chiquito/query.py @@ -134,5 +134,5 @@ def __str__(self: ASTStepType) -> str: def __json__(self): return { - "StepTypeNext": {"id": self.step_type.id, "annotation": self.step_type.name} + "StepTypeNext": {"id": f"{self.step_type.id}", "annotation": self.step_type.name} } diff --git a/src/frontend/python/chiquito/util.py b/src/frontend/python/chiquito/util.py index 0533fd6f..a9621acb 100644 --- a/src/frontend/python/chiquito/util.py +++ b/src/frontend/python/chiquito/util.py @@ -14,11 +14,9 @@ def __json__(self: F): # Convert the integer to a byte array montgomery_form = self.n * R % F.field_modulus byte_array = montgomery_form.to_bytes(32, "little") - # Split into four 64-bit integers - ints = [ - int.from_bytes(byte_array[i * 8 : i * 8 + 8], "little") for i in range(4) - ] - return ints + + # return the hex string + return byte_array.hex() class CustomEncoder(json.JSONEncoder): @@ -29,5 +27,5 @@ def default(self, obj): # int field is the u128 version of uuid. -def uuid() -> int: - return uuid1(node=int.from_bytes([10, 10, 10, 10, 10, 10], byteorder="little")).int +def uuid() -> str: + return uuid1(node=int.from_bytes([10, 10, 10, 10, 10, 10], byteorder="little")).int.__str__() diff --git a/src/frontend/python/chiquito/wit_gen.py b/src/frontend/python/chiquito/wit_gen.py index cec59905..2ab1bf39 100644 --- a/src/frontend/python/chiquito/wit_gen.py +++ b/src/frontend/python/chiquito/wit_gen.py @@ -41,7 +41,7 @@ def __str__(self: StepInstance): # For assignments, return "uuid: (Queriable, F)" rather than "Queriable: F", because JSON doesn't accept Dict as key. def __json__(self: StepInstance): return { - "step_type_uuid": self.step_type_uuid, + "step_type_uuid": self.step_type_uuid.__str__(), "assignments": { lhs.uuid(): [lhs, rhs] for (lhs, rhs) in self.assignments.items() }, diff --git a/src/frontend/python/mod.rs b/src/frontend/python/mod.rs index e6da88aa..32a25726 100644 --- a/src/frontend/python/mod.rs +++ b/src/frontend/python/mod.rs @@ -2,9 +2,11 @@ use pyo3::{ prelude::*, types::{PyDict, PyList, PyLong, PyString}, }; +use serde_json::{from_str, Value}; use crate::{ frontend::dsl::{StepTypeHandler, SuperCircuitContext}, + pil::backend::powdr_pil::chiquito2Pil, plonkish::{ backend::halo2::{ chiquito2Halo2, chiquitoSuperCircuit2Halo2, ChiquitoHalo2, ChiquitoHalo2Circuit, @@ -46,8 +48,10 @@ thread_local! { /// as the key. Return the Rust UUID to Python. The last field of the tuple, `TraceWitness`, is left /// as None, for `chiquito_add_witness_to_rust_id` to insert. pub fn chiquito_ast_to_halo2(ast_json: &str) -> UUID { + let value: Value = from_str(ast_json).expect("Invalid JSON"); + // Attempt to convert `Value` into `SBPIR` let circuit: SBPIR = - serde_json::from_str(ast_json).expect("Json deserialization to Circuit failed."); + serde_json::from_value(value).expect("Deserialization to Circuit failed."); let config = config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}); let (chiquito, assignment_generator) = compile(config, &circuit); @@ -81,6 +85,14 @@ pub fn chiquito_ast_map_store(ast_json: &str) -> UUID { uuid } +pub fn chiquito_ast_to_pil(witness_json: &str, rust_id: UUID, circuit_name: &str) -> String { + let trace_witness: TraceWitness = + serde_json::from_str(witness_json).expect("Json deserialization to TraceWitness failed."); + let (ast, _, _) = rust_id_to_halo2(rust_id); + + chiquito2Pil(ast, Some(trace_witness), circuit_name.to_string()) +} + fn add_assignment_generator_to_rust_id( assignment_generator: AssignmentGenerator, rust_id: UUID, @@ -133,7 +145,7 @@ pub fn chiquito_super_circuit_halo2_mock_prover( let prover = MockProver::::run(k as u32, &circuit, circuit.instance()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("result = {:#?}", result); @@ -166,7 +178,7 @@ pub fn chiquito_halo2_mock_prover(witness_json: &str, rust_id: UUID, k: usize) { let prover = MockProver::::run(k as u32, &circuit, circuit.instance()).unwrap(); - let result = prover.verify_par(); + let result = prover.verify(); println!("{:#?}", result); @@ -203,13 +215,18 @@ impl<'de> Visitor<'de> for CircuitVisitor { let mut q_enable = None; let mut id = None; + println!("------ Visiting map -------"); + while let Some(key) = map.next_key::()? { + println!("key = {}", key); match key.as_str() { "step_types" => { + println!("------ Visiting step_types -------"); if step_types.is_some() { return Err(de::Error::duplicate_field("step_types")); } step_types = Some(map.next_value::>>()?); + println!("step_types = {:#?}", step_types); } "forward_signals" => { if forward_signals.is_some() { @@ -252,13 +269,33 @@ impl<'de> Visitor<'de> for CircuitVisitor { if first_step.is_some() { return Err(de::Error::duplicate_field("first_step")); } - first_step = Some(map.next_value::>()?); + let first_step_opt: Option = map.next_value()?; // Deserialize the value as an optional string + first_step = Some(first_step_opt.map_or(Ok(None), |first_step_str| { + StepTypeUUID::from_str_radix(&first_step_str, 10) + .map(Some) + .map_err(|e| { + de::Error::custom(format!( + "Failed to parse first_step '{}': {}", + first_step_str, e + )) + }) + })?); } "last_step" => { if last_step.is_some() { return Err(de::Error::duplicate_field("last_step")); } - last_step = Some(map.next_value::>()?); + let last_step_opt: Option = map.next_value()?; // Deserialize the value as an optional string + last_step = Some(last_step_opt.map_or(Ok(None), |last_step_str| { + StepTypeUUID::from_str_radix(&last_step_str, 10) + .map(Some) + .map_err(|e| { + de::Error::custom(format!( + "Failed to parse last_step '{}': {}", + last_step_str, e + )) + }) + })?); } "num_steps" => { if num_steps.is_some() { @@ -276,7 +313,10 @@ impl<'de> Visitor<'de> for CircuitVisitor { if id.is_some() { return Err(de::Error::duplicate_field("id")); } - id = Some(map.next_value()?); + let id_str: String = map.next_value()?; + id = Some(id_str.parse::().map_err(|e| { + de::Error::custom(format!("Failed to parse id '{}': {}", id_str, e)) + })?); } _ => { return Err(de::Error::unknown_field( @@ -367,7 +407,10 @@ impl<'de> Visitor<'de> for StepTypeVisitor { if id.is_some() { return Err(de::Error::duplicate_field("id")); } - id = Some(map.next_value()?); + let id_str: String = map.next_value()?; + id = Some(id_str.parse::().map_err(|e| { + de::Error::custom(format!("Failed to parse id '{}': {}", id_str, e)) + })?); } "name" => { if name.is_some() { @@ -620,6 +663,7 @@ impl<'de> Visitor<'de> for QueriableVisitor { let key: String = map .next_key()? .ok_or_else(|| de::Error::custom("map is empty"))?; + match key.as_str() { "Internal" => map.next_value().map(Queriable::Internal), "Forward" => map @@ -628,9 +672,11 @@ impl<'de> Visitor<'de> for QueriableVisitor { "Shared" => map .next_value() .map(|(signal, rotation)| Queriable::Shared(signal, rotation)), - "Fixed" => map - .next_value() - .map(|(signal, rotation)| Queriable::Fixed(signal, rotation)), + "Fixed" => { + println!("Processing Fixed"); + map.next_value() + .map(|(signal, rotation)| Queriable::Fixed(signal, rotation)) + } "StepTypeNext" => map.next_value().map(Queriable::StepTypeNext), _ => Err(de::Error::unknown_variant( &key, @@ -694,7 +740,10 @@ macro_rules! impl_visitor_internal_fixed_steptypehandler { if id.is_some() { return Err(de::Error::duplicate_field("id")); } - id = Some(map.next_value()?); + let id_str: String = map.next_value()?; // Get the UUID as a string + id = Some(id_str.parse::().map_err(|e| { + de::Error::custom(format!("Failed to parse id '{}': {}", id_str, e)) + })?); } "annotation" => { if annotation.is_some() { @@ -750,7 +799,10 @@ macro_rules! impl_visitor_forward_shared { if id.is_some() { return Err(de::Error::duplicate_field("id")); } - id = Some(map.next_value()?); + let id_str: String = map.next_value()?; // Get the UUID as a string + id = Some(id_str.parse::().map_err(|e| { + de::Error::custom(format!("Failed to parse id '{}': {}", id_str, e)) + })?); } "phase" => { if phase.is_some() { @@ -839,7 +891,12 @@ impl<'de> Visitor<'de> for StepInstanceVisitor { if step_type_uuid.is_some() { return Err(de::Error::duplicate_field("step_type_uuid")); } - step_type_uuid = Some(map.next_value()?); + let uuid_str: String = map.next_value()?; // Get the UUID as a string + step_type_uuid = Some( + uuid_str + .parse::() // Assuming the string is in decimal format + .map_err(de::Error::custom)?, + ); } "assignments" => { if assignments.is_some() { @@ -910,118 +967,90 @@ impl<'de> Deserialize<'de> for SBPIR { #[cfg(test)] mod tests { use super::*; + #[test] + #[ignore] fn test_trace_witness() { let json = r#" { "step_instances": [ { - "step_type_uuid": 270606747459021742275781620564109167114, + "step_type_uuid": "270606747459021742275781620564109167114", "assignments": { "270606737951642240564318377467548666378": [ { "Forward": [ { - "id": 270606737951642240564318377467548666378, + "id": "270606737951642240564318377467548666378", "phase": 0, "annotation": "a" }, false ] }, - [ - 55, - 0, - 0, - 0 - ] + "0000000000000000000000000000000000000000000000000000000000000055" ], "270606743497613616562965561253747624458": [ { "Forward": [ { - "id": 270606743497613616562965561253747624458, + "id": "270606743497613616562965561253747624458", "phase": 0, "annotation": "b" }, false ] }, - [ - 89, - 0, - 0, - 0 - ] + "0000000000000000000000000000000000000000000000000000000000000089" ], "270606753004993118272949371872716917258": [ { "Internal": { - "id": 270606753004993118272949371872716917258, + "id": "270606753004993118272949371872716917258", "annotation": "c" } }, - [ - 144, - 0, - 0, - 0 - ] + "0000000000000000000000000000000000000000000000000000000000000144" ] } }, { - "step_type_uuid": 270606783111694873693576112554652600842, + "step_type_uuid": "270606783111694873693576112554652600842", "assignments": { "270606737951642240564318377467548666378": [ { "Forward": [ { - "id": 270606737951642240564318377467548666378, + "id": "270606737951642240564318377467548666378", "phase": 0, "annotation": "a" }, false ] }, - [ - 89, - 0, - 0, - 0 - ] + "0000000000000000000000000000000000000000000000000000000000000089" ], "270606743497613616562965561253747624458": [ { "Forward": [ { - "id": 270606743497613616562965561253747624458, + "id": "270606743497613616562965561253747624458", "phase": 0, "annotation": "b" }, false ] }, - [ - 144, - 0, - 0, - 0 - ] + "0000000000000000000000000000000000000000000000000000000000000144" ], "270606786280821374261518951164072823306": [ { "Internal": { - "id": 270606786280821374261518951164072823306, + "id": "270606786280821374261518951164072823306", "annotation": "c" } }, - [ - 233, - 0, - 0, - 0 - ] + "0000000000000000000000000000000000000000000000000000000000000233" ] } } @@ -1060,11 +1089,11 @@ mod tests { { "step_types": { "258869595755756204079859764249309612554": { - "id": 258869595755756204079859764249309612554, + "id": "258869595755756204079859764249309612554", "name": "fibo_first_step", "signals": [ { - "id": 258869599717164329791616633222308956682, + "id": "258869599717164329791616633222308956682", "annotation": "c" } ], @@ -1076,7 +1105,7 @@ mod tests { { "Forward": [ { - "id": 258869580702405326369584955980151130634, + "id": "258869580702405326369584955980151130634", "phase": 0, "annotation": "a" }, @@ -1085,12 +1114,7 @@ mod tests { }, { "Neg": { - "Const": [ - 1, - 0, - 0, - 0 - ] + "Const": "0000000000000000000000000000000000000000000000000000000000000001" } } ] @@ -1103,7 +1127,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1112,12 +1136,7 @@ mod tests { }, { "Neg": { - "Const": [ - 1, - 0, - 0, - 0 - ] + "Const": "0000000000000000000000000000000000000000000000000000000000000001" } } ] @@ -1130,7 +1149,7 @@ mod tests { { "Forward": [ { - "id": 258869580702405326369584955980151130634, + "id": "258869580702405326369584955980151130634", "phase": 0, "annotation": "a" }, @@ -1140,7 +1159,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1150,7 +1169,7 @@ mod tests { { "Neg": { "Internal": { - "id": 258869599717164329791616633222308956682, + "id": "258869599717164329791616633222308956682", "annotation": "c" } } @@ -1167,7 +1186,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1178,7 +1197,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869580702405326369584955980151130634, + "id": "258869580702405326369584955980151130634", "phase": 0, "annotation": "a" }, @@ -1195,7 +1214,7 @@ mod tests { "Sum": [ { "Internal": { - "id": 258869599717164329791616633222308956682, + "id": "258869599717164329791616633222308956682", "annotation": "c" } }, @@ -1203,7 +1222,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1221,7 +1240,7 @@ mod tests { { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1232,7 +1251,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1250,11 +1269,11 @@ mod tests { } }, "258869628239302834927102989021255174666": { - "id": 258869628239302834927102989021255174666, + "id": "258869628239302834927102989021255174666", "name": "fibo_step", "signals": [ { - "id": 258869632200710960639812650790420089354, + "id": "258869632200710960639812650790420089354", "annotation": "c" } ], @@ -1266,7 +1285,7 @@ mod tests { { "Forward": [ { - "id": 258869580702405326369584955980151130634, + "id": "258869580702405326369584955980151130634", "phase": 0, "annotation": "a" }, @@ -1276,7 +1295,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1286,7 +1305,7 @@ mod tests { { "Neg": { "Internal": { - "id": 258869632200710960639812650790420089354, + "id": "258869632200710960639812650790420089354", "annotation": "c" } } @@ -1303,7 +1322,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1314,7 +1333,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869580702405326369584955980151130634, + "id": "258869580702405326369584955980151130634", "phase": 0, "annotation": "a" }, @@ -1331,7 +1350,7 @@ mod tests { "Sum": [ { "Internal": { - "id": 258869632200710960639812650790420089354, + "id": "258869632200710960639812650790420089354", "annotation": "c" } }, @@ -1339,7 +1358,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1357,7 +1376,7 @@ mod tests { { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1368,7 +1387,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1386,7 +1405,7 @@ mod tests { } }, "258869646461780213207493341245063432714": { - "id": 258869646461780213207493341245063432714, + "id": "258869646461780213207493341245063432714", "name": "padding", "signals": [], "constraints": [], @@ -1398,7 +1417,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1409,7 +1428,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1427,7 +1446,7 @@ mod tests { { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1438,7 +1457,7 @@ mod tests { "Neg": { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1456,17 +1475,17 @@ mod tests { }, "forward_signals": [ { - "id": 258869580702405326369584955980151130634, + "id": "258869580702405326369584955980151130634", "phase": 0, "annotation": "a" }, { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" } @@ -1478,7 +1497,7 @@ mod tests { { "Forward": [ { - "id": 258869587040658327507391136965088381450, + "id": "258869587040658327507391136965088381450", "phase": 0, "annotation": "b" }, @@ -1493,7 +1512,7 @@ mod tests { { "Forward": [ { - "id": 258869589417503202934383108674030275082, + "id": "258869589417503202934383108674030275082", "phase": 0, "annotation": "n" }, @@ -1514,11 +1533,11 @@ mod tests { "258869646461780213207493341245063432714": "padding" }, "fixed_assignments": null, - "first_step": 258869595755756204079859764249309612554, - "last_step": 258869646461780213207493341245063432714, + "first_step": "258869595755756204079859764249309612554", + "last_step": "258869646461780213207493341245063432714", "num_steps": 10, "q_enable": true, - "id": 258867373405797678961444396351437277706 + "id": "258867373405797678961444396351437277706" } "#; let circuit: SBPIR = serde_json::from_str(json).unwrap(); @@ -1529,15 +1548,15 @@ mod tests { fn test_step_type() { let json = r#" { - "id":1, + "id":"1", "name":"fibo", "signals":[ { - "id":1, + "id":"1", "annotation":"a" }, { - "id":2, + "id":"2", "annotation":"b" } ], @@ -1547,18 +1566,18 @@ mod tests { "expr":{ "Sum":[ { - "Const":[1, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000001" }, { "Mul":[ { "Internal":{ - "id":3, + "id":"3", "annotation":"c" } }, { - "Const":[3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" } ] } @@ -1570,14 +1589,14 @@ mod tests { "expr":{ "Sum":[ { - "Const":[1, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000001" }, { "Mul":[ { "Shared":[ { - "id":4, + "id":"4", "phase":2, "annotation":"d" }, @@ -1585,7 +1604,7 @@ mod tests { ] }, { - "Const":[3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" } ] } @@ -1599,14 +1618,14 @@ mod tests { "expr":{ "Sum":[ { - "Const":[1, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000001" }, { "Mul":[ { "Forward":[ { - "id":5, + "id":"5", "phase":1, "annotation":"e" }, @@ -1614,7 +1633,7 @@ mod tests { ] }, { - "Const":[3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" } ] } @@ -1626,21 +1645,21 @@ mod tests { "expr":{ "Sum":[ { - "Const":[1, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000001" }, { "Mul":[ { "Fixed":[ { - "id":6, + "id":"6", "annotation":"e" }, 2 ] }, { - "Const":[3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" } ] } @@ -1669,14 +1688,14 @@ mod tests { "Sum": [ { "Internal": { - "id": 27, + "id": "27", "annotation": "a" } }, { "Fixed": [ { - "id": 28, + "id": "28", "annotation": "b" }, 1 @@ -1685,7 +1704,7 @@ mod tests { { "Shared": [ { - "id": 29, + "id": "29", "phase": 1, "annotation": "c" }, @@ -1695,7 +1714,7 @@ mod tests { { "Forward": [ { - "id": 30, + "id": "30", "phase": 2, "annotation": "d" }, @@ -1704,32 +1723,32 @@ mod tests { }, { "StepTypeNext": { - "id": 31, + "id": "31", "annotation": "e" } }, { - "Const": [3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" }, { "Mul": [ { - "Const": [4, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000004" }, { - "Const": [5, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000005" } ] }, { "Neg": { - "Const": [2, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000002" } }, { "Pow": [ { - "Const": [3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" }, 4 ] @@ -1750,14 +1769,14 @@ mod tests { "Sum": [ { "Internal": { - "id": 27, + "id": "27", "annotation": "a" } }, { "Fixed": [ { - "id": 28, + "id": "28", "annotation": "b" }, 1 @@ -1766,7 +1785,7 @@ mod tests { { "Shared": [ { - "id": 29, + "id": "29", "phase": 1, "annotation": "c" }, @@ -1776,7 +1795,7 @@ mod tests { { "Forward": [ { - "id": 30, + "id": "30", "phase": 2, "annotation": "d" }, @@ -1785,32 +1804,32 @@ mod tests { }, { "StepTypeNext": { - "id": 31, + "id": "31", "annotation": "e" } }, { - "Const": [3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" }, { "Mul": [ { - "Const": [4, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000004" }, { - "Const": [5, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000005" } ] }, { "Neg": { - "Const": [2, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000002" } }, { "Pow": [ { - "Const": [3, 0, 0, 0] + "Const": "0000000000000000000000000000000000000000000000000000000000000003" }, 4 ] @@ -1845,6 +1864,18 @@ fn ast_to_halo2(json: &PyString) -> u128 { uuid } +#[pyfunction] +fn to_pil(witness_json: &PyString, rust_id: &PyLong, circuit_name: &PyString) -> String { + let pil = chiquito_ast_to_pil( + witness_json.to_str().expect("PyString convertion failed."), + rust_id.extract().expect("PyLong convertion failed."), + circuit_name.to_str().expect("PyString convertion failed."), + ); + + println!("{}", pil); + pil +} + #[pyfunction] fn ast_map_store(json: &PyString) -> u128 { let uuid = chiquito_ast_map_store(json.to_str().expect("PyString conversion failed.")); @@ -1903,6 +1934,7 @@ fn rust_chiquito(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(convert_and_print_ast, m)?)?; m.add_function(wrap_pyfunction!(convert_and_print_trace_witness, m)?)?; m.add_function(wrap_pyfunction!(ast_to_halo2, m)?)?; + m.add_function(wrap_pyfunction!(to_pil, m)?)?; m.add_function(wrap_pyfunction!(ast_map_store, m)?)?; m.add_function(wrap_pyfunction!(halo2_mock_prover, m)?)?; m.add_function(wrap_pyfunction!(super_circuit_halo2_mock_prover, m)?)?; diff --git a/src/lib.rs b/src/lib.rs index dd823581..853e71ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod field; pub mod frontend; +pub mod pil; pub mod plonkish; pub mod poly; pub mod sbpir; diff --git a/src/pil/backend/mod.rs b/src/pil/backend/mod.rs new file mode 100644 index 00000000..d57b058b --- /dev/null +++ b/src/pil/backend/mod.rs @@ -0,0 +1 @@ +pub mod powdr_pil; diff --git a/src/pil/backend/powdr_pil.rs b/src/pil/backend/powdr_pil.rs new file mode 100644 index 00000000..0585c45c --- /dev/null +++ b/src/pil/backend/powdr_pil.rs @@ -0,0 +1,226 @@ +use crate::{ + field::Field, + pil::{ + compiler::{compile, compile_super_circuits, PILColumn, PILExpr, PILQuery}, + ir::powdr_pil::PILCircuit, + }, + sbpir::SBPIR, + util::UUID, + wit_gen::TraceWitness, +}; +use std::{ + collections::HashMap, + fmt::{Debug, Write}, +}; +extern crate regex; + +#[allow(non_snake_case)] +/// User generate PIL code using this function. User needs to supply AST, TraceWitness, and a name +/// string for the circuit. +pub fn chiquito2Pil( + ast: SBPIR, + witness: Option>, + circuit_name: String, +) -> String { + // generate PIL IR. + let pil_ir = compile::(&ast, witness, circuit_name, &None); + + // generate Powdr PIL code. + pil_ir_to_powdr_pil::(pil_ir) +} + +// Convert PIL IR to Powdr PIL code. +pub fn pil_ir_to_powdr_pil(pil_ir: PILCircuit) -> String { + let mut pil = String::new(); // The string to return. + + writeln!( + pil, + "// ===== START OF CIRCUIT: {} =====", + pil_ir.circuit_name + ) + .unwrap(); + + // Namespace is equivalent to a circuit in PIL. + writeln!( + pil, + "constant %NUM_STEPS_{} = {};", + pil_ir.circuit_name.to_uppercase(), + pil_ir.num_steps + ) + .unwrap(); + writeln!( + pil, + "namespace {}(%NUM_STEPS_{});", + pil_ir.circuit_name, + pil_ir.circuit_name.to_uppercase() + ) + .unwrap(); + + // Declare witness columns in PIL. + generate_pil_witness_columns(&mut pil, &pil_ir); + + // Declare fixed columns and their assignments in PIL. + generate_pil_fixed_columns(&mut pil, &pil_ir); + + // generate constraints + for expr in pil_ir.constraints { + // recursively convert expressions to PIL strings + let expr_string = convert_to_pil_expr_string(expr.clone()); + // each constraint is in the format of `constraint = 0` + writeln!(pil, "{} = 0;", expr_string).unwrap(); + } + + // generate lookups + for lookup in pil_ir.lookups { + let (selector, src_dest_tuples) = lookup; + let lookup_selector = selector.annotation(); + let mut lookup_source: Vec = Vec::new(); + let mut lookup_destination: Vec = Vec::new(); + for (src, dest) in src_dest_tuples { + lookup_source.push(src.annotation()); + lookup_destination.push(dest.annotation()); + } + // PIL lookups have the format of `selector { src1, src2, ... srcn } in {dest1, dest2, ..., + // destn}`. + writeln!( + pil, + "{} {{{}}} in {{{}}} ", + lookup_selector, + lookup_source.join(", "), + lookup_destination.join(", ") + ) + .unwrap(); + } + + writeln!( + pil, + "// ===== END OF CIRCUIT: {} =====", + pil_ir.circuit_name + ) + .unwrap(); + writeln!(pil).unwrap(); // Separator row for the circuit. + + pil +} + +#[allow(non_snake_case)] +/// User generate PIL code for super circuit using this function. +/// User needs to supply a Vec for `circuit_names`, the order of which should be the same as +/// the order of calling `sub_circuit()` function. +pub fn chiquitoSuperCircuit2Pil( + super_asts: Vec>, + super_trace_witnesses: HashMap>, + ast_id_to_ir_id_mapping: HashMap, + circuit_names: Vec, +) -> String { + let mut pil = String::new(); // The string to return. + + // Generate PIL IRs for each sub circuit in the super circuit. + let pil_irs = compile_super_circuits( + super_asts, + super_trace_witnesses, + ast_id_to_ir_id_mapping, + circuit_names, + ); + + // Generate Powdr PIL code for each sub circuit. + for pil_ir in pil_irs { + let pil_circuit = pil_ir_to_powdr_pil(pil_ir); + writeln!(pil, "{}", pil_circuit).unwrap(); + } + + pil +} + +fn generate_pil_witness_columns(pil: &mut String, pil_ir: &PILCircuit) { + if !pil_ir.col_witness.is_empty() { + writeln!(pil, "// === Witness Columns ===").unwrap(); + let mut col_witness = String::from("col witness "); + + let mut col_witness_vars = pil_ir + .col_witness + .iter() + .map(|col| match col { + PILColumn::Advice(_, annotation) => annotation.clone(), + _ => panic!("Witness column should be an advice column."), + }) + .collect::>(); + + // Get unique witness column annotations + col_witness_vars.sort(); + col_witness_vars.dedup(); + col_witness = col_witness + col_witness_vars.join(", ").as_str() + ";"; + writeln!(pil, "{}", col_witness).unwrap(); + } +} + +fn generate_pil_fixed_columns(pil: &mut String, pil_ir: &PILCircuit) { + if !pil_ir.col_fixed.is_empty() { + writeln!( + pil, + "// === Fixed Columns for Signals and Step Type Selectors ===" + ) + .unwrap(); + for (col, assignments) in pil_ir.col_fixed.iter() { + let fixed_name = match col { + PILColumn::Fixed(_, annotation) => annotation.clone(), + _ => panic!("Fixed column should be an advice or fixed column."), + }; + let mut assignments_string = String::new(); + let assignments_vec = assignments + .iter() + .map(|assignment| format!("{:?}", assignment)) + .collect::>(); + write!( + assignments_string, + "{}", + assignments_vec.join(", ").as_str() + ) + .unwrap(); + writeln!(pil, "col fixed {} = [{}];", fixed_name, assignments_string).unwrap(); + } + } +} + +// Convert PIL expression to Powdr PIL string recursively. +fn convert_to_pil_expr_string(expr: PILExpr) -> String { + match expr { + PILExpr::Const(constant) => format!("{:?}", constant), + PILExpr::Sum(sum) => { + let mut expr_string = String::new(); + for (index, expr) in sum.iter().enumerate() { + expr_string += convert_to_pil_expr_string(expr.clone()).as_str(); + if index != sum.len() - 1 { + expr_string += " + "; + } + } + format!("({})", expr_string) + } + PILExpr::Mul(mul) => { + let mut expr_string = String::new(); + for (index, expr) in mul.iter().enumerate() { + expr_string += convert_to_pil_expr_string(expr.clone()).as_str(); + if index != mul.len() - 1 { + expr_string += " * "; + } + } + expr_string.to_string() + } + PILExpr::Neg(neg) => format!("(-{})", convert_to_pil_expr_string(*neg)), + PILExpr::Pow(pow, power) => { + format!("({})^{}", convert_to_pil_expr_string(*pow), power) + } + PILExpr::Query(queriable) => convert_to_pil_queriable_string(queriable), + } +} + +// Convert PIL query to Powdr PIL string recursively. +fn convert_to_pil_queriable_string(query: PILQuery) -> String { + let (col, rot) = query; + let annotation = col.annotation(); + if rot { + format!("{}'", annotation) + } else { + annotation + } +} diff --git a/src/pil/compiler/mod.rs b/src/pil/compiler/mod.rs new file mode 100644 index 00000000..0928f3ae --- /dev/null +++ b/src/pil/compiler/mod.rs @@ -0,0 +1,578 @@ +use crate::{ + field::Field, + pil::ir::powdr_pil::{PILCircuit, PILLookup}, + poly::Expr, + sbpir::{query::Queriable, SBPIR}, + util::{uuid, UUID}, + wit_gen::TraceWitness, +}; +use std::{collections::HashMap, fmt::Debug, hash::Hash}; +extern crate regex; + +pub fn compile( + ast: &SBPIR, + witness: Option>, + circuit_name: String, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILCircuit { + let col_witness = collect_witness_columns(ast); + + // HashMap of fixed column to fixed assignments + let mut col_fixed = HashMap::new(); + + if let Some(fixed_assignments) = &ast.fixed_assignments { + fixed_assignments + .iter() + .for_each(|(queriable, assignments)| { + let uuid = queriable.uuid(); + col_fixed.insert( + PILColumn::Fixed( + uuid, + clean_annotation(ast.annotations.get(&uuid).unwrap().clone()), + ), + assignments.clone(), + ); + }); + } + + // Get last step instance UUID, so that we can disable transition of that instance + let mut last_step_instance = 0; + + // Insert into col_fixed the map from step type fixed column to vector of {0,1} where 1 means + // the step type is instantiated whereas 0 not. Each vector should have the same length as the + // number of steps. + if !ast.step_types.is_empty() && witness.is_some() { + let step_instances = witness.as_ref().unwrap().step_instances.iter(); + + // Get last step instance, so that we can disable transition of that instance + last_step_instance = step_instances.clone().last().unwrap().step_type_uuid; + + for step_type in ast.step_types.values() { + let step_type_instantiation: Vec = step_instances + .clone() + .map(|step_instance| { + if step_instance.step_type_uuid == step_type.uuid() { + F::ONE + } else { + F::ZERO + } + }) + .collect(); + assert_eq!(step_type_instantiation.len(), ast.num_steps); + let uuid = step_type.uuid(); + col_fixed.insert( + PILColumn::Fixed( + uuid, + clean_annotation(ast.annotations.get(&uuid).unwrap().clone()), + ), + step_type_instantiation, + ); + } + } + + // Create new UUID for ISFIRST and ISLAST. These are fixed columns unique to PIL. + let is_first_uuid = uuid(); + let is_last_uuid = uuid(); + + // ISFIRST and ISLAST are only relevant when there's non zero number of step instances. + let num_step_instances = witness + .as_ref() + .map(|w| w.step_instances.len()) + .unwrap_or(0); + if num_step_instances != 0 { + // 1 for first row and 0 for all other rows; number of rows equals to number of steps + let is_first_assignments = vec![F::ONE] + .into_iter() + .chain(std::iter::repeat(F::ZERO)) + .take(ast.num_steps) + .collect(); + col_fixed.insert( + PILColumn::Fixed(is_first_uuid, String::from("ISFIRST")), + is_first_assignments, + ); + + // 0 for all rows except the last row, which is 1; number of rows equals to number of steps + let is_last_assignments = std::iter::repeat(F::ZERO) + .take(ast.num_steps - 1) + .chain(std::iter::once(F::ONE)) + .collect(); + col_fixed.insert( + PILColumn::Fixed(is_last_uuid, String::from("ISLAST")), + is_last_assignments, + ); + } + + // Compile step type elements, i.e. constraints, transitions, and lookups. + let (mut constraints, lookups) = compile_steps( + ast, + last_step_instance, + is_last_uuid, + super_circuit_annotations_map, + ); + + // Insert pragma_first_step and pragma_last_step as constraints + if let Some(first_step) = ast.first_step { + // is_first * (1 - first_step) = 0 + constraints.push(PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed(is_first_uuid, String::from("ISFIRST")), + false, + )), + PILExpr::Sum(vec![ + PILExpr::Const(F::ONE), + PILExpr::Neg(Box::new(PILExpr::Query(( + PILColumn::Fixed( + first_step, + clean_annotation(ast.annotations.get(&first_step).unwrap().clone()), + ), + false, + )))), + ]), + ])); + } + + if let Some(last_step) = ast.last_step { + // is_last * (1 - last_step) = 0 + constraints.push(PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed(is_last_uuid, String::from("ISLAST")), + false, + )), + PILExpr::Sum(vec![ + PILExpr::Const(F::ONE), + PILExpr::Neg(Box::new(PILExpr::Query(( + PILColumn::Fixed( + last_step, + clean_annotation(ast.annotations.get(&last_step).unwrap().clone()), + ), + false, + )))), + ]), + ])); + } + + PILCircuit { + circuit_name, + num_steps: ast.num_steps, + col_witness, + col_fixed, + constraints, + lookups, + } +} + +pub fn compile_super_circuits( + super_asts: Vec>, + super_trace_witnesses: HashMap>, + ast_id_to_ir_id_mapping: HashMap, + circuit_names: Vec, +) -> Vec> { + assert!(super_asts.len() == circuit_names.len()); + + // Get annotations map for the super circuit, which is a HashMap of object UUID to object + // annotation. + let mut super_circuit_annotations_map: HashMap = HashMap::new(); + + // Loop over each AST. + for (ast, circuit_name) in super_asts.iter().zip(circuit_names.iter()) { + // Create `annotations_map` for each AST, to be added to `super_circuit_annotations_map`. + let mut annotations_map: HashMap = HashMap::new(); + + // First, get AST level annotations. + annotations_map.extend(ast.annotations.clone()); + + // Second, get step level annotations. + for step_type in ast.step_types.values() { + annotations_map.extend(step_type.annotations.clone()); + } + + // Convert annotation to circuit_name.annotation, because this is the general format of + // referring to variables in PIL if there are more than one circuit. + super_circuit_annotations_map.extend(annotations_map.into_iter().map( + |(uuid, annotation)| { + ( + uuid, + format!("{}.{}", circuit_name.clone(), clean_annotation(annotation)), + ) + }, + )); + + // Finally, get annotations for the circuit names. + super_circuit_annotations_map.insert(ast.id, circuit_name.clone()); + } + + // For each AST, find its corresponding TraceWitness. Note that some AST might not have a + // corresponding TraceWitness, so witness is an Option. + let mut pil_irs = Vec::new(); + for (ast, circuit_name) in super_asts.iter().zip(circuit_names.iter()) { + let witness = super_trace_witnesses.get(ast_id_to_ir_id_mapping.get(&ast.id).unwrap()); + + // Create PIL IR + let pil_ir = compile( + ast, + witness.cloned(), + circuit_name.clone(), + &Some(&super_circuit_annotations_map), + ); + + pil_irs.push(pil_ir); + } + + pil_irs +} + +fn collect_witness_columns(ast: &SBPIR) -> Vec { + let mut col_witness = Vec::new(); + + // Collect internal signals to witness columns. + col_witness.extend( + ast.step_types + .values() + .flat_map(|step_type| { + step_type + .signals + .iter() + .map(|signal| { + PILColumn::Advice(signal.uuid(), clean_annotation(signal.annotation())) + }) + .collect::>() + }) + .collect::>(), + ); + + // Collect forward signals to witness columns. + col_witness.extend( + ast.forward_signals + .iter() + .map(|forward_signal| { + PILColumn::Advice( + forward_signal.uuid(), + clean_annotation(forward_signal.annotation()), + ) + }) + .collect::>(), + ); + + // Collect shared signals to witness columns. + col_witness.extend( + ast.shared_signals + .iter() + .map(|shared_signal| { + PILColumn::Advice( + shared_signal.uuid(), + clean_annotation(shared_signal.annotation()), + ) + }) + .collect::>(), + ); + + col_witness +} + +fn compile_steps( + ast: &SBPIR, + last_step_instance: UUID, + is_last_uuid: UUID, + super_circuit_annotations_map: &Option<&HashMap>, +) -> (Vec>, Vec) { + // transitions and constraints all become constraints in PIL + let mut constraints = Vec::new(); + let mut lookups = Vec::new(); + + if !ast.step_types.is_empty() { + ast.step_types.values().for_each(|step_type| { + // Create constraint statements. + constraints.extend( + step_type + .constraints + .iter() + .map(|constraint| { + PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed( + step_type.uuid(), + clean_annotation(step_type.name()), + ), + false, + )), + chiquito_expr_to_pil_expr( + constraint.expr.clone(), + super_circuit_annotations_map, + ), + ]) + }) + .collect::>>(), + ); + + // There's no distinction between constraint and transition in PIL + // However, we do need to identify constraints with rotation in the last row + // and disable them + constraints.extend( + step_type + .transition_constraints + .iter() + .map(|transition| { + let res = PILExpr::Mul(vec![ + PILExpr::Query(( + PILColumn::Fixed( + step_type.uuid(), + clean_annotation(step_type.name()), + ), + false, + )), + chiquito_expr_to_pil_expr( + transition.expr.clone(), + super_circuit_annotations_map, + ), + ]); + if step_type.uuid() == last_step_instance { + PILExpr::Mul(vec![ + PILExpr::Sum(vec![ + PILExpr::Const(F::ONE), + PILExpr::Neg(Box::new(PILExpr::Query(( + PILColumn::Fixed(is_last_uuid, String::from("ISLAST")), + false, + )))), + ]), + res, + ]) + } else { + res + } + }) + .collect::>>(), + ); + + lookups.extend( + step_type + .lookups + .iter() + .map(|lookup| { + ( + PILColumn::Fixed(step_type.uuid(), clean_annotation(step_type.name())), + lookup + .exprs + .iter() + .map(|(lhs, rhs)| { + ( + chiquito_lookup_column_to_pil_column( + lhs.expr.clone(), + super_circuit_annotations_map, + ), + chiquito_lookup_column_to_pil_column( + rhs.clone(), + super_circuit_annotations_map, + ), + ) + }) + .collect::>(), + ) + }) + .collect::>(), + ); + }); + } + + (constraints, lookups) +} + +// Convert lookup columns (src and dest) in Chiquito to PIL column. Note that Chiquito lookup +// columns have to be Expr::Query type. +fn chiquito_lookup_column_to_pil_column( + src: Expr>, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILColumn { + match src { + Expr::Query(queriable) => { + chiquito_queriable_to_pil_query(queriable, super_circuit_annotations_map).0 + } + _ => panic!("Lookup source is not queriable."), + } +} + +// PIL expression and constraint +#[derive(Clone)] +pub enum PILExpr { + Const(F), + Sum(Vec>), + Mul(Vec>), + Neg(Box>), + Pow(Box>, u32), + Query(PILQuery), +} + +fn chiquito_expr_to_pil_expr( + expr: Expr>, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILExpr { + match expr { + Expr::Const(constant) => PILExpr::Const(constant), + Expr::Sum(sum) => { + let mut pil_sum = Vec::new(); + for expr in sum { + pil_sum.push(chiquito_expr_to_pil_expr( + expr, + super_circuit_annotations_map, + )); + } + PILExpr::Sum(pil_sum) + } + Expr::Mul(mul) => { + let mut pil_mul = Vec::new(); + for expr in mul { + pil_mul.push(chiquito_expr_to_pil_expr( + expr, + super_circuit_annotations_map, + )); + } + PILExpr::Mul(pil_mul) + } + Expr::Neg(neg) => PILExpr::Neg(Box::new(chiquito_expr_to_pil_expr( + *neg, + super_circuit_annotations_map, + ))), + Expr::Pow(pow, power) => PILExpr::Pow( + Box::new(chiquito_expr_to_pil_expr( + *pow, + super_circuit_annotations_map, + )), + power, + ), + Expr::Query(queriable) => PILExpr::Query(chiquito_queriable_to_pil_query( + queriable, + super_circuit_annotations_map, + )), + Expr::Halo2Expr(_) => { + panic!("Halo2 native expression not supported by PIL backend.") + } + Expr::MI(_) => { + panic!("MI not supported by PIL backend.") + } + } +} + +pub type PILQuery = (PILColumn, bool); // column, rotation + +#[derive(Clone, PartialEq, Eq, Hash)] +pub enum PILColumn { + Advice(UUID, String), // UUID, annotation + Fixed(UUID, String), +} + +impl PILColumn { + pub fn uuid(&self) -> UUID { + match self { + PILColumn::Advice(uuid, _) => *uuid, + PILColumn::Fixed(uuid, _) => *uuid, + } + } + + pub fn annotation(&self) -> String { + match self { + PILColumn::Advice(_, annotation) => annotation.clone(), + PILColumn::Fixed(_, annotation) => annotation.clone(), + } + } +} + +pub fn clean_annotation(annotation: String) -> String { + annotation.replace(' ', "_") +} + +// Convert queriable to PIL column recursively. Major differences are: 1. PIL doesn't distinguish +// internal, forward, or shared columns as they are all advice; 2. PIL only supports the next +// rotation, so there's no previous or arbitrary rotation. +fn chiquito_queriable_to_pil_query( + query: Queriable, + super_circuit_annotations_map: &Option<&HashMap>, +) -> PILQuery { + match query { + Queriable::Internal(s) => { + if super_circuit_annotations_map.is_none() { + ( + PILColumn::Advice(s.uuid(), clean_annotation(s.annotation())), + false, + ) + } else { + let annotation = super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap(); + ( + PILColumn::Advice(s.uuid(), clean_annotation(annotation.clone())), + false, + ) + } + } + Queriable::Forward(s, rot) => { + if super_circuit_annotations_map.is_none() { + ( + PILColumn::Advice(s.uuid(), clean_annotation(s.annotation())), + rot, + ) + } else { + let annotation = super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap(); + ( + PILColumn::Advice(s.uuid(), clean_annotation(annotation.clone())), + rot, + ) + } + } + Queriable::Shared(s, rot) => { + let annotation = if super_circuit_annotations_map.is_none() { + clean_annotation(s.annotation()) + } else { + super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap() + .clone() + }; + if rot == 0 { + (PILColumn::Advice(s.uuid(), annotation), false) + } else if rot == 1 { + (PILColumn::Advice(s.uuid(), annotation), true) + } else { + panic!( + "PIL backend does not support shared signal with rotation other than 0 or 1." + ) + } + } + Queriable::Fixed(s, rot) => { + let annotation = if super_circuit_annotations_map.is_none() { + clean_annotation(s.annotation()) + } else { + super_circuit_annotations_map + .as_ref() + .unwrap() + .get(&s.uuid()) + .unwrap() + .clone() + }; + if rot == 0 { + (PILColumn::Fixed(s.uuid(), annotation), false) + } else if rot == 1 { + (PILColumn::Fixed(s.uuid(), annotation), true) + } else { + panic!("PIL backend does not support fixed signal with rotation other than 0 or 1.") + } + } + Queriable::StepTypeNext(s) => ( + PILColumn::Fixed(s.uuid(), clean_annotation(s.annotation())), + true, + ), + Queriable::Halo2AdviceQuery(_, _) => { + panic!("Halo2 native advice query not supported by PIL backend.") + } + Queriable::Halo2FixedQuery(_, _) => { + panic!("Halo2 native fixed query not supported by PIL backend.") + } + Queriable::_unaccessible(_) => todo!(), + } +} diff --git a/src/pil/ir/mod.rs b/src/pil/ir/mod.rs new file mode 100644 index 00000000..d57b058b --- /dev/null +++ b/src/pil/ir/mod.rs @@ -0,0 +1 @@ +pub mod powdr_pil; diff --git a/src/pil/ir/powdr_pil.rs b/src/pil/ir/powdr_pil.rs new file mode 100644 index 00000000..576d147b --- /dev/null +++ b/src/pil/ir/powdr_pil.rs @@ -0,0 +1,18 @@ +use crate::pil::compiler::{PILColumn, PILExpr, PILQuery}; +use std::collections::HashMap; +extern crate regex; + +// PIL circuit IR +pub struct PILCircuit { + pub circuit_name: String, + pub num_steps: usize, + pub col_witness: Vec, + pub col_fixed: HashMap>, // column -> assignments + pub constraints: Vec>, + pub lookups: Vec, +} + +// lookup in PIL is the format of selector {src1, src2, ..., srcn} -> {dst1, dst2, ..., dstn} +// PILLookup is a tuple of (selector, Vec) tuples, where selector is converted from +// Chiquito step type to fixed column +pub type PILLookup = (PILColumn, Vec<(PILColumn, PILColumn)>); diff --git a/src/pil/mod.rs b/src/pil/mod.rs new file mode 100644 index 00000000..aab5b126 --- /dev/null +++ b/src/pil/mod.rs @@ -0,0 +1,3 @@ +pub mod backend; +pub mod compiler; +pub mod ir; diff --git a/src/plonkish/backend/halo2.rs b/src/plonkish/backend/halo2.rs index 7cc0082d..d8dfd6d2 100644 --- a/src/plonkish/backend/halo2.rs +++ b/src/plonkish/backend/halo2.rs @@ -215,25 +215,6 @@ impl + Hash> ChiquitoHalo2 { Ok(()) } - fn instance(&self, witness: &Assignments) -> Vec { - let mut instance_values = Vec::new(); - for (column, rotation) in &self.circuit.exposed { - let values = witness - .get(column) - .unwrap_or_else(|| panic!("exposed column not found: {}", column.annotation)); - - if let Some(value) = values.get(*rotation as usize) { - instance_values.push(*value); - } else { - panic!( - "assignment index out of bounds for column: {}", - column.annotation - ); - } - } - instance_values - } - fn annotate_circuit(&self, region: &mut Region) { for column in self.circuit.columns.iter() { match column.ctype { @@ -379,7 +360,7 @@ impl + Hash> ChiquitoHalo2Circuit { pub fn instance(&self) -> Vec> { if !self.compiled.circuit.exposed.is_empty() { if let Some(witness) = &self.witness { - return vec![self.compiled.instance(witness)]; + return vec![self.compiled.circuit.instance(witness)]; } } Vec::new() @@ -444,7 +425,7 @@ impl + Hash> ChiquitoHalo2SuperCircuit { for sub_circuit in &self.sub_circuits { if !sub_circuit.circuit.exposed.is_empty() { - let instance_values = sub_circuit.instance( + let instance_values = sub_circuit.circuit.instance( self.witness .get(&sub_circuit.ir_id) .expect("No matching witness found for given UUID."), diff --git a/src/plonkish/backend/hyperplonk.rs b/src/plonkish/backend/hyperplonk.rs new file mode 100644 index 00000000..622c1c35 --- /dev/null +++ b/src/plonkish/backend/hyperplonk.rs @@ -0,0 +1,379 @@ +use crate::{ + plonkish::ir::{assignments::Assignments, Circuit, Column, ColumnType, PolyExpr}, + util::UUID, +}; +use halo2_proofs::arithmetic::Field; +use plonkish_backend::{ + backend::{PlonkishCircuit, PlonkishCircuitInfo}, + util::expression::{rotate::Rotation, Expression, Query}, +}; +use std::{collections::HashMap, hash::Hash}; + +// get max phase number + 1 to get number of phases +// for example, if the phases slice is [0, 1, 0, 1, 2, 2], then the output will be 3 +fn num_phases(phases: &[usize]) -> usize { + phases.iter().max().copied().unwrap_or_default() + 1 +} + +// get number of columns for each phase given a vector of phases +// for example, if the phases slice is [0, 1, 0, 1, 2, 2], then the output vector will be +// [2, 2, 2] +fn num_by_phase(phases: &[usize]) -> Vec { + phases.iter().copied().fold( + vec![0usize; num_phases(phases)], + |mut num_by_phase, phase| { + num_by_phase[phase] += 1; + num_by_phase + }, + ) +} + +// This function maps each element in the phases slice to its index within the circuit, given an +// offset For example, if the phases slice is [0, 1, 0, 1, 2, 2], and the offset is 3, then the +// output vector will be [3, 5, 4, 6, 7, 8], i.e. [3+0+0, 3+2+0, 3+0+1, 3+2+1, 3+4+0, 3+4+1], i.e. +// [offset+phase_offset+index] +fn idx_order_by_phase(phases: &[usize], offset: usize) -> Vec { + phases + .iter() + .copied() + .scan(phase_offsets(phases), |state, phase| { + let index = state[phase]; + state[phase] += 1; + Some(offset + index) + }) + .collect() +} + +// get vector of advice column phases +fn advice_phases(circuit: &Circuit) -> Vec { + circuit + .columns + .iter() + .filter(|column| column.ctype == ColumnType::Advice) + .map(|column| column.phase) + .collect::>() +} + +// This function computes the offsets for each phase. +// For example, if the phases slice is [0, 1, 0, 1, 2, 2], then the output vector will be +// [0, 2, 4]. +fn phase_offsets(phases: &[usize]) -> Vec { + num_by_phase(phases) + .into_iter() + .scan(0, |state, num| { + let offset = *state; + *state += num; + Some(offset) + }) + .collect() +} + +pub struct ChiquitoHyperPlonkCircuit { + circuit: ChiquitoHyperPlonk, + assignments: Option>, +} + +pub struct ChiquitoHyperPlonk { + k: usize, + instances: Vec>, /* outter vec has length 1, inner vec has length equal to number of + * exposed signals */ + chiquito_ir: Circuit, + num_witness_polys: Vec, + all_uuids: Vec, // the same order as self.chiquito_ir.columns + fixed_uuids: Vec, // the same order as self.chiquito_ir.columns + advice_uuids: Vec, // the same order as self.chiquito_ir.columns + advice_uuids_by_phase: HashMap>, +} + +impl + Hash> ChiquitoHyperPlonk { + fn new(k: usize, circuit: Circuit) -> Self { + // get all column uuids + let all_uuids = circuit + .columns + .iter() + .map(|column| column.id) + .collect::>(); + + // get fixed column uuids + let fixed_uuids = circuit + .columns + .iter() + .filter(|column| column.ctype == ColumnType::Fixed) + .map(|column| column.id) + .collect::>(); + + // get advice column uuids (including step selectors) + let advice_uuids = circuit + .columns + .iter() + .filter(|column| column.ctype == ColumnType::Advice) + .map(|column| column.id) + .collect::>(); + + // check that length of all uuid vectors equals length of all columns + assert_eq!( + fixed_uuids.len() + advice_uuids.len(), + circuit.columns.len() + ); + + // get phase number for all advice columns + let advice_phases = advice_phases(&circuit); + // get number of witness polynomials for each phase + let num_witness_polys = num_by_phase(&advice_phases); + + // given non_selector_advice_phases and non_selector_advice_uuids, which have equal lengths, + // create hashmap of phase to vector of uuids if phase doesn't exist in map, create + // a new vector and insert it into map if phase exists in map, insert the uuid to + // the vector associated with the phase + assert_eq!(advice_phases.len(), advice_uuids.len()); + let advice_uuids_by_phase = advice_phases.iter().zip(advice_uuids.iter()).fold( + HashMap::new(), + |mut map: HashMap>, (phase, uuid)| { + map.entry(*phase).or_default().push(*uuid); + map + }, + ); + + Self { + k, + instances: Vec::default(), + chiquito_ir: circuit, + num_witness_polys, + all_uuids, + fixed_uuids, + advice_uuids, + advice_uuids_by_phase, + } + } + + fn set_instance(&mut self, instance: Vec>) { + self.instances = instance; + } +} + +impl + Hash> ChiquitoHyperPlonkCircuit { + pub fn new(k: usize, circuit: Circuit) -> Self { + let chiquito_hyper_plonk = ChiquitoHyperPlonk::new(k, circuit); + Self { + circuit: chiquito_hyper_plonk, + assignments: None, + } + } + + pub fn set_assignment(&mut self, assignments: Assignments) { + let instances = vec![self.circuit.chiquito_ir.instance(&assignments)]; + self.circuit.set_instance(instances); + self.assignments = Some(assignments); + } +} + +// given column uuid and the vector of all column uuids, get the index or position of the uuid +// has no offset +fn column_idx(column_uuid: UUID, column_uuids: &[UUID]) -> usize { + column_uuids + .iter() + .position(|&uuid| uuid == column_uuid) + .unwrap() +} + +impl + Hash> PlonkishCircuit for ChiquitoHyperPlonkCircuit { + fn circuit_info_without_preprocess( + &self, + ) -> Result, plonkish_backend::Error> { + // there's only one instance column whose length is equal to the number of exposed signals + // in chiquito circuit `num_instances` is a vector of length 1, because we only have + // one instance column + let num_instances = self.circuit.instances.iter().map(Vec::len).collect(); + + // a vector of zero vectors, each zero vector with 2^k length + // number of preprocess is equal to number of fixed columns + let preprocess_polys = + vec![vec![F::ZERO; 1 << self.circuit.k]; self.circuit.fixed_uuids.len()]; + + let advice_idx = self.circuit.advice_idx(); + let constraints: Vec> = self + .circuit + .chiquito_ir + .polys + .iter() + .map(|poly| { + self.circuit + .convert_expression(poly.expr.clone(), &advice_idx) + }) + .collect(); + + let lookups = self + .circuit + .chiquito_ir + .lookups + .iter() + .map(|lookup| { + lookup + .exprs + .iter() + .map(|(input, table)| { + ( + self.circuit.convert_expression(input.clone(), &advice_idx), + self.circuit.convert_expression(table.clone(), &advice_idx), + ) + }) + .collect() + }) + .collect(); + + let max_degree = constraints + .iter() + .map(|constraint| constraint.degree()) + .max(); + + Ok(PlonkishCircuitInfo { + k: self.circuit.k, + num_instances, + preprocess_polys, + num_witness_polys: self.circuit.num_witness_polys.clone(), + num_challenges: vec![0; self.circuit.num_witness_polys.len()], + constraints, + lookups, + permutations: Default::default(), // Chiquito doesn't have permutations + max_degree, + }) + } + + // preprocess fixed assignments + fn circuit_info( + &self, + ) -> Result, plonkish_backend::Error> { + let mut circuit_info = self.circuit_info_without_preprocess()?; + // make sure all fixed assignments are for fixed column type + self.circuit + .chiquito_ir + .fixed_assignments + .iter() + .for_each(|(column, _)| match column.ctype { + ColumnType::Fixed => (), + _ => panic!("fixed assignments must be for fixed column type"), + }); + // get assignments Vec by looking up from fixed_assignments and reorder assignment + // vectors according to self.fixed_uuids. finally bind all Vec to a Vec>. + // here, get Vec from fixed_assigments: HashMap> by looking up the Column + // with uuid + let fixed_assignments = self + .circuit + .fixed_uuids + .iter() + .map(|uuid| { + self.circuit + .chiquito_ir + .fixed_assignments + .get( + &self.circuit.chiquito_ir.columns + [column_idx(*uuid, &self.circuit.all_uuids)], + ) + .unwrap() + .clone() + }) + .collect::>>(); + + circuit_info.preprocess_polys = fixed_assignments; + + Ok(circuit_info) + } + + fn instances(&self) -> &[Vec] { + &self.circuit.instances + } + + fn synthesize( + &self, + phase: usize, + _challenges: &[F], + ) -> Result>, plonkish_backend::Error> { + let assignments = self.assignments.clone().unwrap(); + + let advice_assignments = self + .circuit + .advice_uuids_by_phase + .get(&phase) + .expect("synthesize: phase not found") + .iter() + .map(|uuid| { + assignments + .get( + &self.circuit.chiquito_ir.columns + [column_idx(*uuid, &self.circuit.all_uuids)], + ) + .unwrap() + .clone() + }) + .collect::>>(); + Ok(advice_assignments) + } +} + +impl ChiquitoHyperPlonk { + fn advice_idx(self: &ChiquitoHyperPlonk) -> Vec { + let advice_offset = self.fixed_uuids.len(); + idx_order_by_phase(&advice_phases(&self.chiquito_ir), advice_offset) + } + + fn convert_query( + self: &ChiquitoHyperPlonk, + column: Column, + rotation: i32, + advice_indx: &[usize], + ) -> Expression { + // if column type is fixed, query column will be determined by column_idx function and + // self.fixed_uuids + // if column type is advice, query column will be + // determined by column_idx function and self.advice_uuids + // advice columns come after fixed columns + if column.ctype == ColumnType::Fixed { + let column_idx = column_idx(column.id, &self.fixed_uuids); + Query::new(column_idx, Rotation(rotation)).into() + } else if column.ctype == ColumnType::Advice { + // advice_idx already takes into account of the offset of fixed columns + let column_idx = advice_indx[column_idx(column.id, &self.advice_uuids)]; + Query::new(column_idx, Rotation(rotation)).into() + } else { + panic!("convert_query: column type not supported") + } + } + + fn convert_expression( + self: &ChiquitoHyperPlonk, + poly: PolyExpr, + advice_idx: &Vec, + ) -> Expression { + match poly { + PolyExpr::Const(constant) => Expression::Constant(constant), + PolyExpr::Query((column, rotation, _)) => { + self.convert_query(column, rotation, advice_idx) + } + PolyExpr::Sum(expressions) => { + let mut iter = expressions.iter(); + let first = self.convert_expression(iter.next().unwrap().clone(), advice_idx); + iter.fold(first, |acc, expression| { + acc + self.convert_expression(expression.clone(), advice_idx) + }) + } + PolyExpr::Mul(expressions) => { + let mut iter = expressions.iter(); + let first = self.convert_expression(iter.next().unwrap().clone(), advice_idx); + iter.fold(first, |acc, expression| { + acc * self.convert_expression(expression.clone(), advice_idx) + }) + } + PolyExpr::Neg(expression) => -self.convert_expression(*expression, advice_idx), /* might need to convert to Expression::Negated */ + PolyExpr::Pow(expression, pow) => { + if pow == 0 { + Expression::Constant(F::ONE) + } else { + let expression = self.convert_expression(*expression, advice_idx); + (1..pow).fold(expression.clone(), |acc, _| acc * expression.clone()) + } + } + PolyExpr::Halo2Expr(_) => panic!("halo2 expressions not supported"), + PolyExpr::MI(_) => panic!("MI expressions not supported"), + } + } +} diff --git a/src/plonkish/backend/mod.rs b/src/plonkish/backend/mod.rs index 84f7c9f2..8d078b29 100644 --- a/src/plonkish/backend/mod.rs +++ b/src/plonkish/backend/mod.rs @@ -1,2 +1,3 @@ pub mod halo2; +pub mod hyperplonk; pub mod plaf; diff --git a/src/plonkish/compiler/mod.rs b/src/plonkish/compiler/mod.rs index 1e40c710..b36a2d3a 100644 --- a/src/plonkish/compiler/mod.rs +++ b/src/plonkish/compiler/mod.rs @@ -567,9 +567,67 @@ fn add_halo2_columns(unit: &mut CompilationUnit, ast: &astCircu } #[cfg(test)] -mod tests { - use super::*; - use halo2curves::bn256::Fr; +mod test { + use halo2_proofs::{halo2curves::bn256::Fr, plonk::Any}; + + use super::{cell_manager::SingleRowCellManager, step_selector::SimpleStepSelectorBuilder, *}; + + #[test] + fn test_compiler_config_initialization() { + let cell_manager = SingleRowCellManager::default(); + let step_selector_builder = SimpleStepSelectorBuilder::default(); + + let config = config(cell_manager.clone(), step_selector_builder.clone()); + + assert_eq!( + format!("{:#?}", config.cell_manager), + format!("{:#?}", cell_manager) + ); + assert_eq!( + format!("{:#?}", config.step_selector_builder), + format!("{:#?}", step_selector_builder) + ); + } + + #[test] + fn test_compile() { + let cell_manager = SingleRowCellManager::default(); + let step_selector_builder = SimpleStepSelectorBuilder::default(); + let config = config(cell_manager, step_selector_builder); + + let mock_ast_circuit = astCircuit::::default(); + + let (circuit, assignment_generator) = compile(config, &mock_ast_circuit); + + assert_eq!(circuit.columns.len(), 1); + assert_eq!(circuit.exposed.len(), 0); + assert_eq!(circuit.polys.len(), 0); + assert_eq!(circuit.lookups.len(), 0); + assert_eq!(circuit.fixed_assignments.len(), 1); + assert_eq!(circuit.ast_id, mock_ast_circuit.id); + + assert!(assignment_generator.is_none()); + } + + #[test] + fn test_compile_phase1() { + let cell_manager = SingleRowCellManager::default(); + let step_selector_builder = SimpleStepSelectorBuilder::default(); + let config = config(cell_manager, step_selector_builder); + + let mock_ast_circuit = astCircuit::::default(); + + let (unit, assignment_generator) = compile_phase1(config, &mock_ast_circuit); + + assert_eq!(unit.columns.len(), 1); + assert_eq!(unit.exposed.len(), 0); + assert_eq!(unit.polys.len(), 0); + assert_eq!(unit.lookups.len(), 0); + assert_eq!(unit.fixed_assignments.len(), 0); + assert_eq!(unit.ast_id, mock_ast_circuit.id); + + assert!(assignment_generator.is_none()); + } #[test] #[should_panic] @@ -578,4 +636,19 @@ mod tests { compile_phase2(&mut unit); } + + #[test] + fn test_add_default_columns() { + let mock_ast_circuit = astCircuit::::default(); + + let mut unit = CompilationUnit::from(&mock_ast_circuit); + add_default_columns(&mut unit); + + assert_eq!(unit.columns.len(), 1); + assert_eq!(unit.exposed.len(), 0); + assert_eq!(unit.polys.len(), 0); + assert_eq!(unit.lookups.len(), 0); + assert_eq!(unit.fixed_assignments.len(), 0); + assert_eq!(unit.ast_id, mock_ast_circuit.id); + } } diff --git a/src/plonkish/compiler/step_selector.rs b/src/plonkish/compiler/step_selector.rs index b4d88a6f..390c1c80 100644 --- a/src/plonkish/compiler/step_selector.rs +++ b/src/plonkish/compiler/step_selector.rs @@ -202,10 +202,10 @@ impl StepSelectorBuilder for LogNSelectorBuilder { let n_step_types = unit.step_types.len() as u64; let n_cols = (n_step_types as f64 + 1.0).log2().ceil() as u64; - + println!("n_step_types = {}, n_cols = {}", n_step_types, n_cols); let mut annotation; for index in 0..n_cols { - annotation = format!("'binary selector column {}'", index); + annotation = format!("'step selector for binary column {}'", index); let column = Column::advice(annotation.clone(), 0); selector.columns.push(column.clone()); @@ -258,7 +258,7 @@ fn other_step_type(unit: &CompilationUnit, uuid: UUID) -> Option(); + assert_eq!(unit.selector.columns.len(), 0); + assert_eq!(unit.selector.selector_expr.len(), 0); + assert_eq!(unit.selector.selector_expr_not.len(), 0); + assert_eq!(unit.selector.selector_assignment.len(), 0); + } + + #[test] + fn test_select_step_selector() { + let mut unit = mock_compilation_unit::(); + let step_type = Rc::new(StepType::new(Uuid::nil().as_u128(), "StepType".to_string())); + unit.step_types.insert(step_type.uuid(), step_type.clone()); + + let builder = SimpleStepSelectorBuilder {}; + builder.build(&mut unit); + + let selector = &unit.selector; + let constraint = PolyExpr::Const(Fr::ONE); + + let step_uuid = step_type.uuid(); + let selector_expr = selector + .selector_expr + .get(&step_uuid) + .expect("Step not found") + .clone(); + let expected_expr = PolyExpr::Mul(vec![selector_expr, constraint.clone()]); + + assert_eq!( + format!("{:#?}", selector.select(step_uuid, &constraint)), + format!("{:#?}", expected_expr) + ); + } + + #[test] + fn test_next_step_selector() { + let mut unit = mock_compilation_unit::(); + let step_type = Rc::new(StepType::new(Uuid::nil().as_u128(), "StepType".to_string())); + unit.step_types.insert(step_type.uuid(), step_type.clone()); + + let builder = SimpleStepSelectorBuilder {}; + builder.build(&mut unit); + + let selector = &unit.selector; + let step_uuid = step_type.uuid(); + let step_height = 1; + let expected_expr = selector + .selector_expr + .get(&step_uuid) + .expect("Step not found") + .clone() + .rotate(step_height); + + assert_eq!( + format!("{:#?}", selector.next_expr(step_uuid, step_height as u32)), + format!("{:#?}", expected_expr) + ); + } + + #[test] + fn test_unselect_step_selector() { + let mut unit = mock_compilation_unit::(); + let step_type = Rc::new(StepType::new(Uuid::nil().as_u128(), "StepType".to_string())); + unit.step_types.insert(step_type.uuid(), step_type.clone()); + + let builder = SimpleStepSelectorBuilder {}; + builder.build(&mut unit); + + let selector = &unit.selector; + let step_uuid = step_type.uuid(); + let expected_expr = selector + .selector_expr_not + .get(&step_uuid) + .expect("Step not found") + .clone(); + + assert_eq!( + format!("{:#?}", selector.unselect(step_uuid)), + format!("{:#?}", expected_expr) + ); + } + + #[test] + fn test_simple_step_selector_builder() { + let builder = SimpleStepSelectorBuilder {}; + let mut unit = mock_compilation_unit::(); + + add_step_types_to_unit(&mut unit, 2); + builder.build(&mut unit); + assert_common_tests(&unit, 2); + } + #[test] fn test_log_n_selector_builder_3_step_types() { let builder = LogNSelectorBuilder {}; diff --git a/src/plonkish/ir/assignments.rs b/src/plonkish/ir/assignments.rs index 9f605d7a..60a91a48 100644 --- a/src/plonkish/ir/assignments.rs +++ b/src/plonkish/ir/assignments.rs @@ -132,8 +132,12 @@ impl AssignmentGenerator { } } + pub fn generate_trace_witness(&self, args: TraceArgs) -> TraceWitness { + self.trace_gen.generate(args) + } + pub fn generate(&self, args: TraceArgs) -> Assignments { - let witness = self.trace_gen.generate(args); + let witness = self.generate_trace_witness(args); self.generate_with_witness(witness) } diff --git a/src/plonkish/ir/mod.rs b/src/plonkish/ir/mod.rs index d012f883..a332a1bc 100644 --- a/src/plonkish/ir/mod.rs +++ b/src/plonkish/ir/mod.rs @@ -36,7 +36,28 @@ impl Debug for Circuit { } } -#[derive(Clone, Debug, Hash)] +impl Circuit { + pub(crate) fn instance(&self, witness: &Assignments) -> Vec { + let mut instance_values = Vec::new(); + for (column, rotation) in &self.exposed { + let values = witness + .get(column) + .unwrap_or_else(|| panic!("exposed column not found: {}", column.annotation)); + + if let Some(value) = values.get(*rotation as usize) { + instance_values.push(value.clone()); + } else { + panic!( + "assignment index out of bounds for column: {}", + column.annotation + ); + } + } + instance_values + } +} + +#[derive(Clone, Debug, Hash, PartialEq)] pub enum ColumnType { Advice, Fixed, diff --git a/src/plonkish/ir/sc.rs b/src/plonkish/ir/sc.rs index 4398bc31..cbf8b650 100644 --- a/src/plonkish/ir/sc.rs +++ b/src/plonkish/ir/sc.rs @@ -1,15 +1,17 @@ -use std::{collections::HashMap, hash::Hash, rc::Rc}; +use std::{collections::HashMap, fmt::Debug, hash::Hash, rc::Rc}; -use crate::{field::Field, util::UUID, wit_gen::TraceWitness}; +use crate::{field::Field, sbpir::SBPIR, util::UUID, wit_gen::TraceWitness}; use super::{ assignments::{AssignmentGenerator, Assignments}, Circuit, }; +#[derive(Debug)] pub struct SuperCircuit { sub_circuits: Vec>, mapping: MappingGenerator, + sub_circuit_asts: Vec>, } impl Default for SuperCircuit { @@ -17,6 +19,7 @@ impl Default for SuperCircuit { Self { sub_circuits: Default::default(), mapping: Default::default(), + sub_circuit_asts: Default::default(), } } } @@ -29,6 +32,30 @@ impl SuperCircuit { pub fn get_mapping(&self) -> MappingGenerator { self.mapping.clone() } + + // Needed for the PIL backend. + pub fn add_sub_circuit_ast(&mut self, sub_circuit_ast: SBPIR) { + self.sub_circuit_asts.push(sub_circuit_ast); + } + + // Mapping from AST id to IR id is needed for the PIL backend to match TraceWitness, which has + // IR id, to AST. + pub fn get_ast_id_to_ir_id_mapping(&self) -> HashMap { + let mut ast_id_to_ir_id_mapping: HashMap = HashMap::new(); + self.sub_circuits.iter().for_each(|circuit| { + let ir_id = circuit.id; + let ast_id = circuit.ast_id; + ast_id_to_ir_id_mapping.insert(ast_id, ir_id); + }); + ast_id_to_ir_id_mapping + } +} + +// Needed for the PIL backend. +impl SuperCircuit { + pub fn get_super_asts(&self) -> Vec> { + self.sub_circuit_asts.clone() + } } impl SuperCircuit { @@ -47,22 +74,30 @@ impl SuperCircuit { } pub type SuperAssignments = HashMap>; +pub type SuperTraceWitness = HashMap>; +#[derive(Clone)] pub struct MappingContext { assignments: SuperAssignments, + trace_witnesses: SuperTraceWitness, } -impl Default for MappingContext { +impl Default for MappingContext { fn default() -> Self { Self { assignments: Default::default(), + trace_witnesses: Default::default(), } } } impl MappingContext { pub fn map(&mut self, gen: &AssignmentGenerator, args: TraceArgs) { - self.assignments.insert(gen.uuid(), gen.generate(args)); + let trace_witness = gen.generate_trace_witness(args); + self.trace_witnesses + .insert(gen.uuid(), trace_witness.clone()); + self.assignments + .insert(gen.uuid(), gen.generate_with_witness(trace_witness)); } pub fn map_with_witness( @@ -77,6 +112,10 @@ impl MappingContext { pub fn get_super_assignments(self) -> SuperAssignments { self.assignments } + + pub fn get_trace_witnesses(self) -> SuperTraceWitness { + self.trace_witnesses + } } pub type Mapping = dyn Fn(&mut MappingContext, MappingArgs) + 'static; @@ -93,6 +132,12 @@ impl Clone for MappingGenerator { } } +impl std::fmt::Debug for MappingGenerator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MappingGenerator") + } +} + impl Default for MappingGenerator { fn default() -> Self { Self { @@ -113,4 +158,179 @@ impl MappingGenerator { ctx.get_super_assignments() } + + // Needed for the PIL backend. + pub fn generate_super_trace_witnesses(&self, args: MappingArgs) -> SuperTraceWitness { + let mut ctx = MappingContext::default(); + + (self.mapping)(&mut ctx, args); + + ctx.get_trace_witnesses() + } +} + +#[cfg(test)] +mod test { + use halo2_proofs::halo2curves::bn256::Fr; + + use crate::{ + plonkish::{ + compiler::{cell_manager::Placement, step_selector::StepSelector}, + ir::Column, + }, + util::uuid, + wit_gen::{AutoTraceGenerator, TraceGenerator}, + }; + + use super::*; + + #[test] + fn test_default() { + let super_circuit: SuperCircuit = Default::default(); + + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits), + format!("{:#?}", Vec::>::default()) + ); + assert_eq!( + format!("{:#?}", super_circuit.mapping), + format!("{:#?}", MappingGenerator::::default()) + ); + } + + #[test] + fn test_add_sub_circuit() { + let mut super_circuit: SuperCircuit = Default::default(); + + fn simple_circuit() -> Circuit { + let columns = vec![Column::advice('a', 0)]; + let exposed = vec![(Column::advice('a', 0), 2)]; + let polys = vec![]; + let lookups = vec![]; + let fixed_assignments = Default::default(); + + Circuit { + columns, + exposed, + polys, + lookups, + fixed_assignments, + id: uuid(), + ast_id: uuid(), + } + } + + let sub_circuit = simple_circuit(); + + super_circuit.add_sub_circuit(sub_circuit.clone()); + + assert_eq!(super_circuit.sub_circuits.len(), 1); + assert_eq!(super_circuit.sub_circuits[0].id, sub_circuit.id); + assert_eq!(super_circuit.sub_circuits[0].ast_id, sub_circuit.ast_id); + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits[0].columns), + format!("{:#?}", sub_circuit.columns) + ); + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits[0].exposed), + format!("{:#?}", sub_circuit.exposed) + ); + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits[0].polys), + format!("{:#?}", sub_circuit.polys) + ); + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits[0].lookups), + format!("{:#?}", sub_circuit.lookups) + ); + } + + #[test] + fn test_get_sub_circuits() { + fn simple_circuit() -> Circuit { + let columns = vec![Column::advice('a', 0)]; + let exposed = vec![(Column::advice('a', 0), 2)]; + let polys = vec![]; + let lookups = vec![]; + let fixed_assignments = Default::default(); + + Circuit { + columns, + exposed, + polys, + lookups, + fixed_assignments, + id: uuid(), + ast_id: uuid(), + } + } + + let super_circuit: SuperCircuit = SuperCircuit { + sub_circuits: vec![simple_circuit()], + mapping: Default::default(), + sub_circuit_asts: Default::default(), + }; + + let sub_circuits = super_circuit.get_sub_circuits(); + + assert_eq!(sub_circuits.len(), 1); + assert_eq!(sub_circuits[0].id, super_circuit.sub_circuits[0].id); + } + + #[test] + fn test_mapping_context_default() { + let ctx = MappingContext::::default(); + + assert_eq!( + format!("{:#?}", ctx.assignments), + format!("{:#?}", SuperAssignments::::default()) + ); + } + + fn simple_assignment_generator() -> AssignmentGenerator { + AssignmentGenerator::new( + vec![Column::advice('a', 0)], + Placement { + forward: HashMap::new(), + shared: HashMap::new(), + fixed: HashMap::new(), + steps: HashMap::new(), + columns: vec![], + base_height: 0, + }, + StepSelector::default(), + TraceGenerator::default(), + AutoTraceGenerator::default(), + 1, + uuid(), + ) + } + + #[test] + fn test_mapping_context_map() { + let mut ctx = MappingContext::::default(); + + assert_eq!(ctx.assignments.len(), 0); + + let gen = simple_assignment_generator(); + + ctx.map(&gen, ()); + + assert_eq!(ctx.assignments.len(), 1); + } + + #[test] + fn test_mapping_context_map_with_witness() { + let mut ctx = MappingContext::::default(); + + let gen = simple_assignment_generator(); + + let witness = TraceWitness:: { + step_instances: vec![], + }; + + ctx.map_with_witness(&gen, witness); + + assert_eq!(ctx.assignments.len(), 1); + } } diff --git a/src/poly/mielim.rs b/src/poly/mielim.rs index 190cf08c..63b0a554 100644 --- a/src/poly/mielim.rs +++ b/src/poly/mielim.rs @@ -67,7 +67,7 @@ fn mi_elimination_recursive< #[cfg(test)] mod test { - use halo2curves::bn256::Fr; + use halo2_proofs::halo2curves::bn256::Fr; use crate::{ poly::{mielim::mi_elimination, Expr}, diff --git a/src/poly/mod.rs b/src/poly/mod.rs index fbc61bd0..0f8bb296 100644 --- a/src/poly/mod.rs +++ b/src/poly/mod.rs @@ -97,10 +97,10 @@ impl Expr { Expr::Const(v) => Some(*v), Expr::Sum(ses) => ses .iter() - .fold(Some(F::ZERO), |acc, se| Some(acc? + se.eval(assignments)?)), + .try_fold(F::ZERO, |acc, se| Some(acc + se.eval(assignments)?)), Expr::Mul(ses) => ses .iter() - .fold(Some(F::ONE), |acc, se| Some(acc? * se.eval(assignments)?)), + .try_fold(F::ONE, |acc, se| Some(acc * se.eval(assignments)?)), Expr::Neg(se) => Some(F::ZERO - se.eval(assignments)?), Expr::Pow(se, exp) => Some(se.eval(assignments)?.pow([*exp as u64])), Expr::Query(q) => assignments.get(q).copied(), @@ -270,7 +270,7 @@ impl ConstrDecomp { #[cfg(test)] mod test { - use halo2curves::bn256::Fr; + use halo2_proofs::halo2curves::bn256::Fr; use crate::{field::Field, poly::VarAssignments}; @@ -322,4 +322,69 @@ mod test { assert_eq!(experiment.eval(&assignments), None) } + + #[test] + fn test_degree_expr() { + use super::Expr::*; + + let expr: Expr = + (Query("a") * Query("a")) + (Query("c") * Query("d")) - Const(Fr::ONE); + + assert_eq!(expr.degree(), 2); + + let expr: Expr = + (Query("a") * Query("a")) + (Query("c") * Query("d")) * Query("e"); + + assert_eq!(expr.degree(), 3); + } + + #[test] + fn test_expr_sum() { + use super::Expr::*; + + let lhs: Expr = Query("a") + Query("b"); + + let rhs: Expr = Query("c") + Query("d"); + + assert_eq!( + format!("({:?} + {:?})", lhs, rhs), + format!("{:?}", Sum(vec![lhs, rhs])) + ); + } + + #[test] + fn test_expr_mul() { + use super::Expr::*; + + let lhs: Expr = Query("a") * Query("b"); + + let rhs: Expr = Query("c") * Query("d"); + + assert_eq!( + format!("({:?} * {:?})", lhs, rhs), + format!("{:?}", Mul(vec![lhs, rhs])) + ); + } + + #[test] + fn test_expr_neg() { + use super::Expr::*; + + let expr: Expr = Query("a") + Query("b"); + + assert_eq!( + format!("(-{:?})", expr), + format!("{:?}", Neg(Box::new(expr))) + ); + + let lhs: Expr = Query("a") * Query("b"); + let rhs: Expr = Query("c") + Query("d"); + + let expr: Expr = lhs.clone() - rhs.clone(); + + assert_eq!( + format!("{:?}", Sum(vec![lhs, Neg(Box::new(rhs))])), + format!("{:?}", expr) + ); + } } diff --git a/src/poly/reduce.rs b/src/poly/reduce.rs index a928a26a..64625788 100644 --- a/src/poly/reduce.rs +++ b/src/poly/reduce.rs @@ -181,8 +181,7 @@ fn reduce_degree_mul( #[cfg(test)] mod test { - use halo2curves::bn256::Fr; + use halo2_proofs::halo2curves::bn256::Fr; use crate::{ poly::{ diff --git a/src/sbpir/mod.rs b/src/sbpir/mod.rs index b5a77029..217f2ba8 100644 --- a/src/sbpir/mod.rs +++ b/src/sbpir/mod.rs @@ -204,6 +204,28 @@ impl SBPIR { } } +impl SBPIR { + pub fn clone_without_trace(&self) -> SBPIR { + SBPIR { + step_types: self.step_types.clone(), + forward_signals: self.forward_signals.clone(), + shared_signals: self.shared_signals.clone(), + fixed_signals: self.fixed_signals.clone(), + halo2_advice: self.halo2_advice.clone(), + halo2_fixed: self.halo2_fixed.clone(), + exposed: self.exposed.clone(), + annotations: self.annotations.clone(), + trace: None, // Remove the trace. + fixed_assignments: self.fixed_assignments.clone(), + first_step: self.first_step, + last_step: self.last_step, + num_steps: self.num_steps, + q_enable: self.q_enable, + id: self.id, + } + } +} + pub type FixedGen = dyn Fn(&mut FixedGenContext) + 'static; pub type StepTypeUUID = UUID; @@ -253,6 +275,10 @@ impl StepType { self.id } + pub fn name(&self) -> String { + self.name.clone() + } + pub fn add_signal>(&mut self, name: N) -> InternalSignal { let name = name.into(); let signal = InternalSignal::new(name.clone()); @@ -417,6 +443,10 @@ impl ForwardSignal { pub fn phase(&self) -> usize { self.phase } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -450,6 +480,10 @@ impl SharedSignal { pub fn phase(&self) -> usize { self.phase } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -476,6 +510,10 @@ impl FixedSignal { pub fn uuid(&self) -> UUID { self.id } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug)] @@ -508,6 +546,10 @@ impl InternalSignal { pub fn uuid(&self) -> UUID { self.id } + + pub fn annotation(&self) -> String { + self.annotation.to_string() + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] diff --git a/src/sbpir/query.rs b/src/sbpir/query.rs index 82cf8f45..5701b0d1 100644 --- a/src/sbpir/query.rs +++ b/src/sbpir/query.rs @@ -211,4 +211,161 @@ mod tests { let expr5: Expr> = Expr::Pow(Box::new(Expr::Const(a)), 2); assert_eq!(format!("{:?}", expr5), "(0xa)^2"); } + + #[test] + fn test_next_for_forward_signal() { + let forward_signal = ForwardSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Forward(forward_signal, false); + let next_queriable = queriable.next(); + + assert_eq!(next_queriable, Queriable::Forward(forward_signal, true)); + } + + #[test] + #[should_panic(expected = "jarrl: cannot rotate next(forward)")] + fn test_next_for_forward_signal_panic() { + let forward_signal = ForwardSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Forward(forward_signal, true); + let _ = queriable.next(); // This should panic + } + + #[test] + fn test_next_for_shared_signal() { + let shared_signal = SharedSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Shared(shared_signal, 0); + let next_queriable = queriable.next(); + + assert_eq!(next_queriable, Queriable::Shared(shared_signal, 1)); + } + + #[test] + fn test_next_for_fixed_signal() { + let fixed_signal = FixedSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Fixed(fixed_signal, 0); + let next_queriable = queriable.next(); + + assert_eq!(next_queriable, Queriable::Fixed(fixed_signal, 1)); + } + + #[test] + #[should_panic(expected = "can only next a forward, shared, fixed, or halo2 column")] + fn test_next_for_internal_signal_panic() { + let internal_signal = InternalSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Internal(internal_signal); + let _ = queriable.next(); // This should panic + } + + #[test] + fn test_prev_for_shared_signal() { + let shared_signal = SharedSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Shared(shared_signal, 1); + let prev_queriable = queriable.prev(); + + assert_eq!(prev_queriable, Queriable::Shared(shared_signal, 0)); + } + + #[test] + fn test_prev_for_fixed_signal() { + let fixed_signal = FixedSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Fixed(fixed_signal, 1); + let prev_queriable = queriable.prev(); + + assert_eq!(prev_queriable, Queriable::Fixed(fixed_signal, 0)); + } + + #[test] + #[should_panic(expected = "can only prev a shared or fixed column")] + fn test_prev_for_forward_signal_panic() { + let forward_signal = ForwardSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Forward(forward_signal, true); + let _ = queriable.prev(); // This should panic + } + + #[test] + #[should_panic(expected = "can only prev a shared or fixed column")] + fn test_prev_for_internal_signal_panic() { + let internal_signal = InternalSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Internal(internal_signal); + let _ = queriable.prev(); // This should panic + } + + #[test] + fn test_rot_for_shared_signal() { + let shared_signal = SharedSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Shared(shared_signal, 1); + let rot_queriable = queriable.rot(2); + + assert_eq!(rot_queriable, Queriable::Shared(shared_signal, 3)); + } + + #[test] + fn test_rot_for_fixed_signal() { + let fixed_signal = FixedSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Fixed(fixed_signal, 1); + let rot_queriable = queriable.rot(2); + + assert_eq!(rot_queriable, Queriable::Fixed(fixed_signal, 3)); + } + + #[test] + #[should_panic(expected = "can only rot a shared or fixed column")] + fn test_rot_for_forward_signal_panic() { + let forward_signal = ForwardSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Forward(forward_signal, true); + let _ = queriable.rot(2); // This should panic + } + + #[test] + #[should_panic(expected = "can only rot a shared or fixed column")] + fn test_rot_for_internal_signal_panic() { + let internal_signal = InternalSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Internal(internal_signal); + let _ = queriable.rot(2); // This should panic + } }