diff --git a/cs_derive/src/lib.rs b/cs_derive/src/lib.rs index a644692..45996c0 100644 --- a/cs_derive/src/lib.rs +++ b/cs_derive/src/lib.rs @@ -6,6 +6,7 @@ mod allocatable; mod selectable; pub(crate) mod utils; mod var_length_encodable; +mod witness_var_length_encodable; mod witness_hook; #[proc_macro_derive(CSSelectable, attributes(CSSelectableBound))] @@ -31,3 +32,9 @@ pub fn derive_witness_hook(input: TokenStream) -> TokenStream { pub fn derive_var_length_encodable(input: TokenStream) -> TokenStream { self::var_length_encodable::derive_var_length_encodable(input) } + +#[proc_macro_derive(WitVarLengthEncodable, attributes(WitnessVarLengthEncodableBound))] +#[proc_macro_error::proc_macro_error] +pub fn derive_witness_length_encodable(input: TokenStream) -> TokenStream { + self::witness_var_length_encodable::derive_witness_var_length_encodable(input) +} diff --git a/cs_derive/src/var_length_encodable/mod.rs b/cs_derive/src/var_length_encodable/mod.rs index e59598d..b481c38 100644 --- a/cs_derive/src/var_length_encodable/mod.rs +++ b/cs_derive/src/var_length_encodable/mod.rs @@ -24,8 +24,6 @@ pub(crate) fn derive_var_length_encodable( let mut length_impls = TokenStream::new(); let mut field_impls = TokenStream::new(); - let mut witness_to_buffer_impls = TokenStream::new(); - let mut witness_length_impls = TokenStream::new(); let extra_bound = if let Some(bound) = fetch_attr_from_list(BOUND_ATTR_NAME, &attrs) { let bound = syn::parse_str::(&bound).expect("must parse bound as WhereClause"); @@ -43,7 +41,7 @@ pub(crate) fn derive_var_length_encodable( for field in named_fields.named.iter() { let field_ident = field.ident.clone().expect("should have a field elem ident"); match field.ty { - Type::Array(ref array_ty) => { + Type::Array(ref _array_ty) => { let field_impl = quote! { total_len += CircuitVarLengthEncodable::::encoding_length(&self.#field_ident); }; @@ -53,16 +51,8 @@ pub(crate) fn derive_var_length_encodable( CircuitVarLengthEncodable::::encode_to_buffer(&self.#field_ident, cs, dst); }; field_impls.extend(field_impl); - let wit_to_buf_impl = quote! { - <#array_ty as CircuitVarLengthEncodable>::encode_witness_to_buffer(&witness.#field_ident, dst); - }; - witness_to_buffer_impls.extend(wit_to_buf_impl); - let wit_length_impl = quote! { - total_len += <#array_ty as CircuitVarLengthEncodable>::witness_encoding_length(&witness.#field_ident); - }; - witness_length_impls.extend(wit_length_impl); } - Type::Path(ref path_ty) => { + Type::Path(ref _path_ty) => { let field_impl = quote! { total_len += CircuitVarLengthEncodable::::encoding_length(&self.#field_ident); }; @@ -71,14 +61,6 @@ pub(crate) fn derive_var_length_encodable( CircuitVarLengthEncodable::::encode_to_buffer(&self.#field_ident, cs, dst); }; field_impls.extend(field_impl); - let wit_to_buf_impl = quote! { - <#path_ty as CircuitVarLengthEncodable>::encode_witness_to_buffer(&witness.#field_ident, dst); - }; - witness_to_buffer_impls.extend(wit_to_buf_impl); - let wit_length_impl = quote! { - total_len += <#path_ty as CircuitVarLengthEncodable>::witness_encoding_length(&witness.#field_ident); - }; - witness_length_impls.extend(wit_length_impl); } _ => abort_call_site!("only array and path types are allowed"), }; @@ -124,15 +106,6 @@ pub(crate) fn derive_var_length_encodable( fn encode_to_buffer>(&self, cs: &mut CS, dst: &mut Vec) { #field_impls } - fn witness_encoding_length(witness: &Self::Witness) -> usize { - let mut total_len = 0; - #witness_length_impls - - total_len - } - fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { - #witness_to_buffer_impls - } } }; diff --git a/cs_derive/src/witness_var_length_encodable/mod.rs b/cs_derive/src/witness_var_length_encodable/mod.rs new file mode 100644 index 0000000..195cc6a --- /dev/null +++ b/cs_derive/src/witness_var_length_encodable/mod.rs @@ -0,0 +1,98 @@ +use proc_macro2::{Span, TokenStream}; +use proc_macro_error::abort_call_site; +use quote::quote; +use syn::{ + parse_macro_input, token::Comma, DeriveInput, GenericParam, Type, + WhereClause, +}; + +use crate::utils::*; + +const BOUND_ATTR_NAME: &'static str = "WitnessVarLengthEncodableBound"; + +pub(crate) fn derive_witness_var_length_encodable( + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let derived_input = parse_macro_input!(input as DeriveInput); + let DeriveInput { + ident, + generics, + data, + attrs, + .. + } = derived_input.clone(); + + let mut witness_to_buffer_impls = TokenStream::new(); + let mut witness_length_impls = TokenStream::new(); + + let extra_bound = if let Some(bound) = fetch_attr_from_list(BOUND_ATTR_NAME, &attrs) { + let bound = syn::parse_str::(&bound).expect("must parse bound as WhereClause"); + + Some(bound) + } else { + None + }; + + let bound = merge_where_clauses(generics.where_clause.clone(), extra_bound); + + match data { + syn::Data::Struct(ref struct_data) => match struct_data.fields { + syn::Fields::Named(ref named_fields) => { + for field in named_fields.named.iter() { + let field_ident = field.ident.clone().expect("should have a field elem ident"); + match field.ty { + Type::Array(ref array_ty) => { + let wit_to_buf_impl = quote! { + <#array_ty as WitnessVarLengthEncodable>::encode_witness_to_buffer(&witness.#field_ident, dst); + }; + witness_to_buffer_impls.extend(wit_to_buf_impl); + let wit_length_impl = quote! { + total_len += <#array_ty as WitnessVarLengthEncodable>::witness_encoding_length(&witness.#field_ident); + }; + witness_length_impls.extend(wit_length_impl); + } + Type::Path(ref path_ty) => { + let wit_to_buf_impl = quote! { + <#path_ty as WitnessVarLengthEncodable>::encode_witness_to_buffer(&witness.#field_ident, dst); + }; + witness_to_buffer_impls.extend(wit_to_buf_impl); + let wit_length_impl = quote! { + total_len += <#path_ty as WitnessVarLengthEncodable>::witness_encoding_length(&witness.#field_ident); + }; + witness_length_impls.extend(wit_length_impl); + } + _ => abort_call_site!("only array and path types are allowed"), + }; + } + } + _ => abort_call_site!("only named fields are allowed!"), + }, + _ => abort_call_site!("only struct types are allowed!"), + } + + let comma = Comma(Span::call_site()); + + let field_generic_param = syn::parse_str::(&"F: SmallField").unwrap(); + let has_engine_param = has_proper_small_field_parameter(&generics.params, &field_generic_param); + if has_engine_param == false { + panic!("Expected to have `F: SmallField` somewhere in bounds"); + } + + let type_params_of_allocated_struct = get_type_params_from_generics(&generics, &comma); + + let expanded = quote! { + impl #generics WitnessVarLengthEncodable for #ident<#type_params_of_allocated_struct> #bound { + fn witness_encoding_length(witness: &Self::Witness) -> usize { + let mut total_len = 0; + #witness_length_impls + + total_len + } + fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec) { + #witness_to_buffer_impls + } + } + }; + + proc_macro::TokenStream::from(expanded) +} diff --git a/src/gadgets/boolean/mod.rs b/src/gadgets/boolean/mod.rs index ed95ddb..855c779 100644 --- a/src/gadgets/boolean/mod.rs +++ b/src/gadgets/boolean/mod.rs @@ -702,7 +702,7 @@ impl Boolean { } } -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for Boolean { #[inline(always)] @@ -712,6 +712,9 @@ impl CircuitVarLengthEncodable for Boolean { fn encode_to_buffer>(&self, _cs: &mut CS, dst: &mut Vec) { dst.push(self.variable); } +} + +impl WitnessVarLengthEncodable for Boolean { fn witness_encoding_length(_witness: &Self::Witness) -> usize { 1 } diff --git a/src/gadgets/num/mod.rs b/src/gadgets/num/mod.rs index ce323ed..7967061 100644 --- a/src/gadgets/num/mod.rs +++ b/src/gadgets/num/mod.rs @@ -1214,7 +1214,7 @@ pub fn dot_product_using_dot_product_gate result } -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for Num { #[inline(always)] @@ -1224,6 +1224,9 @@ impl CircuitVarLengthEncodable for Num { fn encode_to_buffer>(&self, _cs: &mut CS, dst: &mut Vec) { dst.push(self.get_variable()); } +} + +impl WitnessVarLengthEncodable for Num { #[inline(always)] fn witness_encoding_length(_witness: &Self::Witness) -> usize { 1 diff --git a/src/gadgets/queue/mod.rs b/src/gadgets/queue/mod.rs index 20f3221..30234bf 100644 --- a/src/gadgets/queue/mod.rs +++ b/src/gadgets/queue/mod.rs @@ -582,14 +582,14 @@ pub fn simulate_new_tail< use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; use crate::serde_utils::BigArraySerde; -#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitnessHookable)] +#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitVarLengthEncodable, WitnessHookable)] #[derivative(Clone, Copy, Debug)] pub struct QueueState { pub head: [Num; N], pub tail: QueueTailState, } -#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitnessHookable)] +#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitVarLengthEncodable, WitnessHookable)] #[derivative(Clone, Copy, Debug)] pub struct QueueTailState { pub tail: [Num; N], diff --git a/src/gadgets/queue/queue_optimizer/sponge_optimizer.rs b/src/gadgets/queue/queue_optimizer/sponge_optimizer.rs index 2ccb759..0efb1c3 100644 --- a/src/gadgets/queue/queue_optimizer/sponge_optimizer.rs +++ b/src/gadgets/queue/queue_optimizer/sponge_optimizer.rs @@ -1,6 +1,6 @@ use super::*; -#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitnessHookable)] +#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitVarLengthEncodable, WitnessHookable)] #[derivative(Clone, Debug, Copy)] pub struct SpongeRoundRequest { pub initial_state: [Num; SWIDTH], diff --git a/src/gadgets/recursion/allocated_vk.rs b/src/gadgets/recursion/allocated_vk.rs index d8b8f97..fdf75c9 100644 --- a/src/gadgets/recursion/allocated_vk.rs +++ b/src/gadgets/recursion/allocated_vk.rs @@ -101,7 +101,7 @@ where } } -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl>> CircuitVarLengthEncodable for AllocatedVerificationKey @@ -123,7 +123,11 @@ impl>> CircuitVarLengthEncodable el.encode_to_buffer(cs, dst); } } +} +impl>> WitnessVarLengthEncodable + for AllocatedVerificationKey +{ fn witness_encoding_length(witness: &Self::Witness) -> usize { let cap_size = witness.setup_merkle_tree_cap.len(); assert!(cap_size > 0); @@ -138,3 +142,4 @@ impl>> CircuitVarLengthEncodable } } } + diff --git a/src/gadgets/recursion/recursive_tree_hasher.rs b/src/gadgets/recursion/recursive_tree_hasher.rs index 4a04fa2..60737a6 100644 --- a/src/gadgets/recursion/recursive_tree_hasher.rs +++ b/src/gadgets/recursion/recursive_tree_hasher.rs @@ -25,7 +25,8 @@ pub trait CircuitTreeHasher>: + Eq + std::fmt::Debug + CSAllocatable - + CircuitVarLengthEncodable; + + CircuitVarLengthEncodable + + WitnessVarLengthEncodable; fn new>(cs: &mut CS) -> Self; fn placeholder_output>(cs: &mut CS) -> Self::CircuitOutput; @@ -87,6 +88,8 @@ pub trait RecursiveTreeHasher>: use crate::gadgets::round_function::CircuitSimpleAlgebraicSponge; +use super::traits::encodable::WitnessVarLengthEncodable; + impl< F: SmallField, const AW: usize, diff --git a/src/gadgets/traits/encodable.rs b/src/gadgets/traits/encodable.rs index 36b13af..818d3df 100644 --- a/src/gadgets/traits/encodable.rs +++ b/src/gadgets/traits/encodable.rs @@ -7,6 +7,11 @@ pub trait CircuitEncodable: 'static + Send + Sync + CSAllocatable { fn encode>(&self, cs: &mut CS) -> [Variable; N]; +} + +pub trait WitnessEncodable: + 'static + Send + Sync + CSAllocatable +{ fn encode_witness(witness: &Self::Witness, dst: &mut Vec); } @@ -20,7 +25,11 @@ pub trait CircuitVarLengthEncodable: { fn encoding_length(&self) -> usize; fn encode_to_buffer>(&self, cs: &mut CS, dst: &mut Vec); +} +pub trait WitnessVarLengthEncodable: + 'static + Send + Sync + CSAllocatable +{ fn witness_encoding_length(witness: &Self::Witness) -> usize; fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec); } @@ -52,7 +61,11 @@ impl> CircuitVarL el.encode_to_buffer(cs, dst); } } +} +impl> WitnessVarLengthEncodable + for [T; N] +{ fn witness_encoding_length(witness: &Self::Witness) -> usize { debug_assert!(N > 0); @@ -74,6 +87,9 @@ impl CircuitVarLengthEncodable for () { fn encode_to_buffer>(&self, _cs: &mut CS, _dst: &mut Vec) { // do nothing } +} + +impl WitnessVarLengthEncodable for () { #[inline(always)] fn witness_encoding_length(_witness: &Self::Witness) -> usize { 0 diff --git a/src/gadgets/u16/mod.rs b/src/gadgets/u16/mod.rs index 083240f..2041552 100644 --- a/src/gadgets/u16/mod.rs +++ b/src/gadgets/u16/mod.rs @@ -559,7 +559,7 @@ impl UInt16 { } } -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for UInt16 { #[inline(always)] @@ -569,6 +569,9 @@ impl CircuitVarLengthEncodable for UInt16 { fn encode_to_buffer>(&self, _cs: &mut CS, dst: &mut Vec) { dst.push(self.variable); } +} + +impl WitnessVarLengthEncodable for UInt16 { #[inline(always)] fn witness_encoding_length(_witness: &Self::Witness) -> usize { 1 diff --git a/src/gadgets/u160/mod.rs b/src/gadgets/u160/mod.rs index c90cae0..da6965c 100644 --- a/src/gadgets/u160/mod.rs +++ b/src/gadgets/u160/mod.rs @@ -245,7 +245,7 @@ use crate::gadgets::traits::selectable::MultiSelectable; // so we degrade to default impl via normal select impl MultiSelectable for UInt160 {} -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for UInt160 { #[inline(always)] @@ -255,6 +255,9 @@ impl CircuitVarLengthEncodable for UInt160 { fn encode_to_buffer>(&self, cs: &mut CS, dst: &mut Vec) { CircuitVarLengthEncodable::::encode_to_buffer(&self.inner, cs, dst); } +} + +impl WitnessVarLengthEncodable for UInt160 { #[inline(always)] fn witness_encoding_length(_witness: &Self::Witness) -> usize { 5 diff --git a/src/gadgets/u256/mod.rs b/src/gadgets/u256/mod.rs index 420e816..9006642 100644 --- a/src/gadgets/u256/mod.rs +++ b/src/gadgets/u256/mod.rs @@ -403,7 +403,7 @@ use crate::gadgets::traits::selectable::MultiSelectable; // so we degrade to default impl via normal select impl MultiSelectable for UInt256 {} -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for UInt256 { #[inline(always)] @@ -413,6 +413,9 @@ impl CircuitVarLengthEncodable for UInt256 { fn encode_to_buffer>(&self, cs: &mut CS, dst: &mut Vec) { CircuitVarLengthEncodable::::encode_to_buffer(&self.inner, cs, dst); } +} + +impl WitnessVarLengthEncodable for UInt256 { #[inline(always)] fn witness_encoding_length(_witness: &Self::Witness) -> usize { 8 diff --git a/src/gadgets/u32/mod.rs b/src/gadgets/u32/mod.rs index d9807ba..18b2980 100644 --- a/src/gadgets/u32/mod.rs +++ b/src/gadgets/u32/mod.rs @@ -876,7 +876,7 @@ impl UInt32 { } } -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for UInt32 { #[inline(always)] @@ -886,6 +886,9 @@ impl CircuitVarLengthEncodable for UInt32 { fn encode_to_buffer>(&self, _cs: &mut CS, dst: &mut Vec) { dst.push(self.variable); } +} + +impl WitnessVarLengthEncodable for UInt32 { #[inline(always)] fn witness_encoding_length(_witness: &Self::Witness) -> usize { 1 diff --git a/src/gadgets/u512/mod.rs b/src/gadgets/u512/mod.rs index 4d05311..50ec3df 100644 --- a/src/gadgets/u512/mod.rs +++ b/src/gadgets/u512/mod.rs @@ -376,7 +376,7 @@ use crate::gadgets::traits::selectable::MultiSelectable; // so we degrade to default impl via normal select impl MultiSelectable for UInt512 {} -use crate::gadgets::traits::encodable::CircuitVarLengthEncodable; +use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable}; impl CircuitVarLengthEncodable for UInt512 { #[inline(always)] @@ -386,6 +386,9 @@ impl CircuitVarLengthEncodable for UInt512 { fn encode_to_buffer>(&self, cs: &mut CS, dst: &mut Vec) { CircuitVarLengthEncodable::::encode_to_buffer(&self.inner, cs, dst); } +} + +impl WitnessVarLengthEncodable for UInt512 { #[inline(always)] fn witness_encoding_length(_witness: &Self::Witness) -> usize { 16 diff --git a/src/gadgets/u8/mod.rs b/src/gadgets/u8/mod.rs index 486df33..12f39db 100644 --- a/src/gadgets/u8/mod.rs +++ b/src/gadgets/u8/mod.rs @@ -1,3 +1,5 @@ +use traits::encodable::WitnessVarLengthEncodable; + use super::tables::ch4::Ch4Table; use super::tables::trixor4::TriXor4Table; use super::tables::xor8::Xor8Table; @@ -577,6 +579,9 @@ impl CircuitVarLengthEncodable for UInt8 { fn encode_to_buffer>(&self, _cs: &mut CS, dst: &mut Vec) { dst.push(self.variable); } +} + +impl WitnessVarLengthEncodable for UInt8 { #[inline(always)] fn witness_encoding_length(_witness: &Self::Witness) -> usize { 1