diff --git a/cs_derive/src/lib.rs b/cs_derive/src/lib.rs index a644692..61880ea 100644 --- a/cs_derive/src/lib.rs +++ b/cs_derive/src/lib.rs @@ -7,6 +7,7 @@ mod selectable; pub(crate) mod utils; mod var_length_encodable; mod witness_hook; +mod witness_var_length_encodable; #[proc_macro_derive(CSSelectable, attributes(CSSelectableBound))] #[proc_macro_error::proc_macro_error] @@ -31,3 +32,9 @@ pub fn derive_witness_hook(input: TokenStream) -> TokenStream { pub fn derive_var_length_encodable(input: TokenStream) -> TokenStream { self::var_length_encodable::derive_var_length_encodable(input) } + +#[proc_macro_derive(WitVarLengthEncodable, attributes(WitnessVarLengthEncodableBound))] +#[proc_macro_error::proc_macro_error] +pub fn derive_witness_length_encodable(input: TokenStream) -> TokenStream { + self::witness_var_length_encodable::derive_witness_var_length_encodable(input) +} diff --git a/cs_derive/src/var_length_encodable/mod.rs b/cs_derive/src/var_length_encodable/mod.rs index ed99ff1..b481c38 100644 --- a/cs_derive/src/var_length_encodable/mod.rs +++ b/cs_derive/src/var_length_encodable/mod.rs @@ -52,7 +52,7 @@ pub(crate) fn derive_var_length_encodable( }; field_impls.extend(field_impl); } - Type::Path(_) => { + Type::Path(ref _path_ty) => { let field_impl = quote! { total_len += CircuitVarLengthEncodable::::encoding_length(&self.#field_ident); }; diff --git a/cs_derive/src/witness_var_length_encodable/mod.rs b/cs_derive/src/witness_var_length_encodable/mod.rs new file mode 100644 index 0000000..44a5b21 --- /dev/null +++ b/cs_derive/src/witness_var_length_encodable/mod.rs @@ -0,0 +1,95 @@ +use proc_macro2::{Span, TokenStream}; +use proc_macro_error::abort_call_site; +use quote::quote; +use syn::{parse_macro_input, token::Comma, DeriveInput, GenericParam, Type, WhereClause}; + +use crate::utils::*; + +const BOUND_ATTR_NAME: &'static str = "WitnessVarLengthEncodableBound"; + +pub(crate) fn derive_witness_var_length_encodable( + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let derived_input = parse_macro_input!(input as DeriveInput); + let DeriveInput { + ident, + generics, + data, + attrs, + .. + } = derived_input.clone(); + + let mut witness_to_buffer_impls = TokenStream::new(); + let mut witness_length_impls = TokenStream::new(); + + let extra_bound = if let Some(bound) = fetch_attr_from_list(BOUND_ATTR_NAME, &attrs) { + let bound = syn::parse_str::(&bound).expect("must parse bound as WhereClause"); + + Some(bound) + } else { + None + }; + + let bound = merge_where_clauses(generics.where_clause.clone(), extra_bound); + + match data { + syn::Data::Struct(ref struct_data) => match struct_data.fields { + syn::Fields::Named(ref named_fields) => { + for field in named_fields.named.iter() { + let field_ident = field.ident.clone().expect("should have a field elem ident"); + match field.ty { + Type::Array(ref array_ty) => { + let wit_to_buf_impl = quote! { + <#array_ty as WitnessVarLengthEncodable>::encode_witness_to_buffer(&witness.#field_ident, dst); + }; + witness_to_buffer_impls.extend(wit_to_buf_impl); + let wit_length_impl = quote! { + total_len += <#array_ty as WitnessVarLengthEncodable>::witness_encoding_length(&witness.#field_ident); + }; + witness_length_impls.extend(wit_length_impl); + } + Type::Path(ref path_ty) => { + let wit_to_buf_impl = quote! { + <#path_ty as WitnessVarLengthEncodable>::encode_witness_to_buffer(&witness.#field_ident, dst); + }; + witness_to_buffer_impls.extend(wit_to_buf_impl); + let wit_length_impl = quote! { + total_len += <#path_ty as WitnessVarLengthEncodable>::witness_encoding_length(&witness.#field_ident); + }; + witness_length_impls.extend(wit_length_impl); + } + _ => abort_call_site!("only array and path types are allowed"), + }; + } + } + _ => abort_call_site!("only named fields are allowed!"), + }, + _ => abort_call_site!("only struct types are allowed!"), + } + + let comma = Comma(Span::call_site()); + + let field_generic_param = syn::parse_str::(&"F: SmallField").unwrap(); + let has_engine_param = has_proper_small_field_parameter(&generics.params, &field_generic_param); + if has_engine_param == false { + panic!("Expected to have `F: SmallField` somewhere in bounds"); + } + + let type_params_of_allocated_struct = get_type_params_from_generics(&generics, &comma); + + let expanded = quote! { + impl #generics WitnessVarLengthEncodable for #ident<#type_params_of_allocated_struct> #bound { + fn witness_encoding_length(witness: &Self::Witness) -> usize { + let mut total_len = 0; + #witness_length_impls + + total_len + } + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { + #witness_to_buffer_impls + } + } + }; + + proc_macro::TokenStream::from(expanded) +} diff --git a/src/gadgets/boolean/mod.rs b/src/gadgets/boolean/mod.rs index fc3959f..855c779 100644 --- a/src/gadgets/boolean/mod.rs +++ b/src/gadgets/boolean/mod.rs @@ -702,7 +702,7 @@ impl Boolean { } } -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for Boolean { #[inline(always)] @@ -713,3 +713,13 @@ impl CircuitVarLengthEncodable for Boolean { dst.push(self.variable); } } + +impl WitnessVarLengthEncodable for Boolean { + fn witness_encoding_length(_witness: &Self::Witness) -> usize { + 1 + } + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { + let val = F::from_raw_u64_unchecked(*witness as u64); + dst.push(val); + } +} diff --git a/src/gadgets/num/mod.rs b/src/gadgets/num/mod.rs index fd7e97b..7967061 100644 --- a/src/gadgets/num/mod.rs +++ b/src/gadgets/num/mod.rs @@ -1214,7 +1214,7 @@ pub fn dot_product_using_dot_product_gate result } -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for Num { #[inline(always)] @@ -1226,6 +1226,16 @@ impl CircuitVarLengthEncodable for Num { } } +impl WitnessVarLengthEncodable for Num { + #[inline(always)] + fn witness_encoding_length(_witness: &Self::Witness) -> usize { + 1 + } + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { + dst.push(*witness); + } +} + use crate::gadgets::traits::allocatable::CSAllocatableExt; impl CSAllocatableExt for Num { diff --git a/src/gadgets/queue/mod.rs b/src/gadgets/queue/mod.rs index 20f3221..88e046f 100644 --- a/src/gadgets/queue/mod.rs +++ b/src/gadgets/queue/mod.rs @@ -6,7 +6,7 @@ use super::boolean::Boolean; use super::num::Num; use super::traits::allocatable::*; use super::u32::UInt32; -use super::{traits::encodable::CircuitEncodable, *}; +use super::{traits::encodable::CircuitEncodable, traits::encodable::WitnessVarLengthEncodable, *}; use crate::algebraic_props::round_function::AbsorptionModeOverwrite; use crate::algebraic_props::round_function::AlgebraicRoundFunction; use crate::config::CSConfig; @@ -582,14 +582,28 @@ pub fn simulate_new_tail< use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; use crate::serde_utils::BigArraySerde; -#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitnessHookable)] +#[derive( + Derivative, + CSAllocatable, + CSSelectable, + CSVarLengthEncodable, + WitVarLengthEncodable, + WitnessHookable, +)] #[derivative(Clone, Copy, Debug)] pub struct QueueState { pub head: [Num; N], pub tail: QueueTailState, } -#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitnessHookable)] +#[derive( + Derivative, + CSAllocatable, + CSSelectable, + CSVarLengthEncodable, + WitVarLengthEncodable, + WitnessHookable, +)] #[derivative(Clone, Copy, Debug)] pub struct QueueTailState { pub tail: [Num; N], diff --git a/src/gadgets/queue/queue_optimizer/sponge_optimizer.rs b/src/gadgets/queue/queue_optimizer/sponge_optimizer.rs index 2ccb759..e020af3 100644 --- a/src/gadgets/queue/queue_optimizer/sponge_optimizer.rs +++ b/src/gadgets/queue/queue_optimizer/sponge_optimizer.rs @@ -1,6 +1,13 @@ use super::*; -#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitnessHookable)] +#[derive( + Derivative, + CSAllocatable, + CSSelectable, + CSVarLengthEncodable, + WitVarLengthEncodable, + WitnessHookable, +)] #[derivative(Clone, Debug, Copy)] pub struct SpongeRoundRequest { pub initial_state: [Num; SWIDTH], diff --git a/src/gadgets/recursion/allocated_vk.rs b/src/gadgets/recursion/allocated_vk.rs index 2cac8a7..4772469 100644 --- a/src/gadgets/recursion/allocated_vk.rs +++ b/src/gadgets/recursion/allocated_vk.rs @@ -101,7 +101,7 @@ where } } -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl>> CircuitVarLengthEncodable for AllocatedVerificationKey @@ -124,3 +124,21 @@ impl>> CircuitVarLengthEncodable } } } + +impl>> WitnessVarLengthEncodable + for AllocatedVerificationKey +{ + fn witness_encoding_length(witness: &Self::Witness) -> usize { + let cap_size = witness.setup_merkle_tree_cap.len(); + assert!(cap_size > 0); + let el_size = H::CircuitOutput::witness_encoding_length(&witness.setup_merkle_tree_cap[0]); + + el_size * cap_size + } + + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { + for el in witness.setup_merkle_tree_cap.iter() { + H::CircuitOutput::encode_witness_to_buffer(el, dst) + } + } +} diff --git a/src/gadgets/recursion/recursive_tree_hasher.rs b/src/gadgets/recursion/recursive_tree_hasher.rs index 4a04fa2..60737a6 100644 --- a/src/gadgets/recursion/recursive_tree_hasher.rs +++ b/src/gadgets/recursion/recursive_tree_hasher.rs @@ -25,7 +25,8 @@ pub trait CircuitTreeHasher>: + Eq + std::fmt::Debug + CSAllocatable - + CircuitVarLengthEncodable; + + CircuitVarLengthEncodable + + WitnessVarLengthEncodable; fn new>(cs: &mut CS) -> Self; fn placeholder_output>(cs: &mut CS) -> Self::CircuitOutput; @@ -87,6 +88,8 @@ pub trait RecursiveTreeHasher>: use crate::gadgets::round_function::CircuitSimpleAlgebraicSponge; +use super::traits::encodable::WitnessVarLengthEncodable; + impl< F: SmallField, const AW: usize, diff --git a/src/gadgets/traits/encodable.rs b/src/gadgets/traits/encodable.rs index fbf63e8..818d3df 100644 --- a/src/gadgets/traits/encodable.rs +++ b/src/gadgets/traits/encodable.rs @@ -9,6 +9,12 @@ pub trait CircuitEncodable: fn encode>(&self, cs: &mut CS) -> [Variable; N]; } +pub trait WitnessEncodable: + 'static + Send + Sync + CSAllocatable +{ + fn encode_witness(witness: &Self::Witness, dst: &mut Vec); +} + pub trait CircuitEncodableExt: CircuitEncodable + CSAllocatableExt { @@ -21,6 +27,13 @@ pub trait CircuitVarLengthEncodable: fn encode_to_buffer>(&self, cs: &mut CS, dst: &mut Vec); } +pub trait WitnessVarLengthEncodable: + 'static + Send + Sync + CSAllocatable +{ + fn witness_encoding_length(witness: &Self::Witness) -> usize; + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec); +} + // unfortunately default implementation is impossible as compiler can not have constraint "for all N" // impl> CircuitVarLengthEncodable for T { @@ -50,6 +63,22 @@ impl> CircuitVarL } } +impl> WitnessVarLengthEncodable + for [T; N] +{ + fn witness_encoding_length(witness: &Self::Witness) -> usize { + debug_assert!(N > 0); + + N * T::witness_encoding_length(&witness[0]) + } + + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { + for el in witness.iter() { + T::encode_witness_to_buffer(el, dst); + } + } +} + impl CircuitVarLengthEncodable for () { #[inline(always)] fn encoding_length(&self) -> usize { @@ -59,3 +88,13 @@ impl CircuitVarLengthEncodable for () { // do nothing } } + +impl WitnessVarLengthEncodable for () { + #[inline(always)] + fn witness_encoding_length(_witness: &Self::Witness) -> usize { + 0 + } + fn encode_witness_to_buffer(_witness: &Self::Witness, _dst: &mut Vec) { + // do nothing + } +} diff --git a/src/gadgets/u16/mod.rs b/src/gadgets/u16/mod.rs index 1d96b7b..2041552 100644 --- a/src/gadgets/u16/mod.rs +++ b/src/gadgets/u16/mod.rs @@ -559,7 +559,7 @@ impl UInt16 { } } -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for UInt16 { #[inline(always)] @@ -570,3 +570,13 @@ impl CircuitVarLengthEncodable for UInt16 { dst.push(self.variable); } } + +impl WitnessVarLengthEncodable for UInt16 { + #[inline(always)] + fn witness_encoding_length(_witness: &Self::Witness) -> usize { + 1 + } + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { + dst.push(F::from_u64_with_reduction(*witness as u64)); + } +} diff --git a/src/gadgets/u160/mod.rs b/src/gadgets/u160/mod.rs index 13d8f53..da6965c 100644 --- a/src/gadgets/u160/mod.rs +++ b/src/gadgets/u160/mod.rs @@ -245,7 +245,7 @@ use crate::gadgets::traits::selectable::MultiSelectable; // so we degrade to default impl via normal select impl MultiSelectable for UInt160 {} -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for UInt160 { #[inline(always)] @@ -257,6 +257,17 @@ impl CircuitVarLengthEncodable for UInt160 { } } +impl WitnessVarLengthEncodable for UInt160 { + #[inline(always)] + fn witness_encoding_length(_witness: &Self::Witness) -> usize { + 5 + } + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { + let chunks = decompose_address_as_u32x5(*witness); + chunks.map(|el| UInt32::encode_witness_to_buffer(&el, dst)); + } +} + use crate::gadgets::traits::allocatable::CSPlaceholder; impl CSPlaceholder for UInt160 { diff --git a/src/gadgets/u256/mod.rs b/src/gadgets/u256/mod.rs index 5047f97..9006642 100644 --- a/src/gadgets/u256/mod.rs +++ b/src/gadgets/u256/mod.rs @@ -403,7 +403,7 @@ use crate::gadgets::traits::selectable::MultiSelectable; // so we degrade to default impl via normal select impl MultiSelectable for UInt256 {} -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for UInt256 { #[inline(always)] @@ -415,6 +415,17 @@ impl CircuitVarLengthEncodable for UInt256 { } } +impl WitnessVarLengthEncodable for UInt256 { + #[inline(always)] + fn witness_encoding_length(_witness: &Self::Witness) -> usize { + 8 + } + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { + let chunks = decompose_u256_as_u32x8(*witness); + chunks.map(|el| UInt32::encode_witness_to_buffer(&el, dst)); + } +} + use crate::gadgets::traits::allocatable::CSPlaceholder; impl CSPlaceholder for UInt256 { diff --git a/src/gadgets/u32/mod.rs b/src/gadgets/u32/mod.rs index 33cda3a..18b2980 100644 --- a/src/gadgets/u32/mod.rs +++ b/src/gadgets/u32/mod.rs @@ -876,7 +876,7 @@ impl UInt32 { } } -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for UInt32 { #[inline(always)] @@ -888,6 +888,16 @@ impl CircuitVarLengthEncodable for UInt32 { } } +impl WitnessVarLengthEncodable for UInt32 { + #[inline(always)] + fn witness_encoding_length(_witness: &Self::Witness) -> usize { + 1 + } + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { + dst.push(F::from_u64_with_reduction(*witness as u64)); + } +} + use crate::gadgets::traits::allocatable::CSPlaceholder; impl CSPlaceholder for UInt32 { diff --git a/src/gadgets/u512/mod.rs b/src/gadgets/u512/mod.rs index bb3781b..50ec3df 100644 --- a/src/gadgets/u512/mod.rs +++ b/src/gadgets/u512/mod.rs @@ -376,7 +376,7 @@ use crate::gadgets::traits::selectable::MultiSelectable; // so we degrade to default impl via normal select impl MultiSelectable for UInt512 {} -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for UInt512 { #[inline(always)] @@ -388,6 +388,17 @@ impl CircuitVarLengthEncodable for UInt512 { } } +impl WitnessVarLengthEncodable for UInt512 { + #[inline(always)] + fn witness_encoding_length(_witness: &Self::Witness) -> usize { + 16 + } + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { + let chunks = decompose_u512_as_u32x16(*witness); + chunks.map(|el| UInt32::encode_witness_to_buffer(&el, dst)); + } +} + use crate::gadgets::traits::allocatable::CSPlaceholder; impl CSPlaceholder for UInt512 { diff --git a/src/gadgets/u8/mod.rs b/src/gadgets/u8/mod.rs index 9490bb4..12f39db 100644 --- a/src/gadgets/u8/mod.rs +++ b/src/gadgets/u8/mod.rs @@ -1,3 +1,5 @@ +use traits::encodable::WitnessVarLengthEncodable; + use super::tables::ch4::Ch4Table; use super::tables::trixor4::TriXor4Table; use super::tables::xor8::Xor8Table; @@ -579,6 +581,16 @@ impl CircuitVarLengthEncodable for UInt8 { } } +impl WitnessVarLengthEncodable for UInt8 { + #[inline(always)] + fn witness_encoding_length(_witness: &Self::Witness) -> usize { + 1 + } + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { + dst.push(F::from_u64_with_reduction(*witness as u64)); + } +} + use crate::gadgets::traits::allocatable::CSPlaceholder; impl CSPlaceholder for UInt8 {