Skip to content
This repository has been archived by the owner on Aug 15, 2024. It is now read-only.

Commit

Permalink
Move witness encoding logic to separate trait
Browse files Browse the repository at this point in the history
  • Loading branch information
0xVolosnikov committed Jul 30, 2024
1 parent 83bd5bf commit 0dd8673
Show file tree
Hide file tree
Showing 16 changed files with 169 additions and 41 deletions.
7 changes: 7 additions & 0 deletions cs_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod allocatable;
mod selectable;
pub(crate) mod utils;
mod var_length_encodable;
mod witness_var_length_encodable;
mod witness_hook;

#[proc_macro_derive(CSSelectable, attributes(CSSelectableBound))]
Expand All @@ -31,3 +32,9 @@ pub fn derive_witness_hook(input: TokenStream) -> TokenStream {
pub fn derive_var_length_encodable(input: TokenStream) -> TokenStream {
self::var_length_encodable::derive_var_length_encodable(input)
}

#[proc_macro_derive(WitVarLengthEncodable, attributes(WitnessVarLengthEncodableBound))]
#[proc_macro_error::proc_macro_error]
pub fn derive_witness_length_encodable(input: TokenStream) -> TokenStream {
self::witness_var_length_encodable::derive_witness_var_length_encodable(input)
}
31 changes: 2 additions & 29 deletions cs_derive/src/var_length_encodable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ pub(crate) fn derive_var_length_encodable(

let mut length_impls = TokenStream::new();
let mut field_impls = TokenStream::new();
let mut witness_to_buffer_impls = TokenStream::new();
let mut witness_length_impls = TokenStream::new();

let extra_bound = if let Some(bound) = fetch_attr_from_list(BOUND_ATTR_NAME, &attrs) {
let bound = syn::parse_str::<WhereClause>(&bound).expect("must parse bound as WhereClause");
Expand All @@ -43,7 +41,7 @@ pub(crate) fn derive_var_length_encodable(
for field in named_fields.named.iter() {
let field_ident = field.ident.clone().expect("should have a field elem ident");
match field.ty {
Type::Array(ref array_ty) => {
Type::Array(ref _array_ty) => {
let field_impl = quote! {
total_len += CircuitVarLengthEncodable::<F>::encoding_length(&self.#field_ident);
};
Expand All @@ -53,16 +51,8 @@ pub(crate) fn derive_var_length_encodable(
CircuitVarLengthEncodable::<F>::encode_to_buffer(&self.#field_ident, cs, dst);
};
field_impls.extend(field_impl);
let wit_to_buf_impl = quote! {
<#array_ty as CircuitVarLengthEncodable<F>>::encode_witness_to_buffer(&witness.#field_ident, dst);
};
witness_to_buffer_impls.extend(wit_to_buf_impl);
let wit_length_impl = quote! {
total_len += <#array_ty as CircuitVarLengthEncodable<F>>::witness_encoding_length(&witness.#field_ident);
};
witness_length_impls.extend(wit_length_impl);
}
Type::Path(ref path_ty) => {
Type::Path(ref _path_ty) => {
let field_impl = quote! {
total_len += CircuitVarLengthEncodable::<F>::encoding_length(&self.#field_ident);
};
Expand All @@ -71,14 +61,6 @@ pub(crate) fn derive_var_length_encodable(
CircuitVarLengthEncodable::<F>::encode_to_buffer(&self.#field_ident, cs, dst);
};
field_impls.extend(field_impl);
let wit_to_buf_impl = quote! {
<#path_ty as CircuitVarLengthEncodable<F>>::encode_witness_to_buffer(&witness.#field_ident, dst);
};
witness_to_buffer_impls.extend(wit_to_buf_impl);
let wit_length_impl = quote! {
total_len += <#path_ty as CircuitVarLengthEncodable<F>>::witness_encoding_length(&witness.#field_ident);
};
witness_length_impls.extend(wit_length_impl);
}
_ => abort_call_site!("only array and path types are allowed"),
};
Expand Down Expand Up @@ -124,15 +106,6 @@ pub(crate) fn derive_var_length_encodable(
fn encode_to_buffer<CS: ConstraintSystem<F>>(&self, cs: &mut CS, dst: &mut Vec<Variable>) {
#field_impls
}
fn witness_encoding_length(witness: &Self::Witness) -> usize {
let mut total_len = 0;
#witness_length_impls

total_len
}
fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec<F>) {
#witness_to_buffer_impls
}
}
};

Expand Down
98 changes: 98 additions & 0 deletions cs_derive/src/witness_var_length_encodable/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use proc_macro2::{Span, TokenStream};
use proc_macro_error::abort_call_site;
use quote::quote;
use syn::{
parse_macro_input, token::Comma, DeriveInput, GenericParam, Type,
WhereClause,
};

use crate::utils::*;

const BOUND_ATTR_NAME: &'static str = "WitnessVarLengthEncodableBound";

pub(crate) fn derive_witness_var_length_encodable(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let derived_input = parse_macro_input!(input as DeriveInput);
let DeriveInput {
ident,
generics,
data,
attrs,
..
} = derived_input.clone();

let mut witness_to_buffer_impls = TokenStream::new();
let mut witness_length_impls = TokenStream::new();

let extra_bound = if let Some(bound) = fetch_attr_from_list(BOUND_ATTR_NAME, &attrs) {
let bound = syn::parse_str::<WhereClause>(&bound).expect("must parse bound as WhereClause");

Some(bound)
} else {
None
};

let bound = merge_where_clauses(generics.where_clause.clone(), extra_bound);

match data {
syn::Data::Struct(ref struct_data) => match struct_data.fields {
syn::Fields::Named(ref named_fields) => {
for field in named_fields.named.iter() {
let field_ident = field.ident.clone().expect("should have a field elem ident");
match field.ty {
Type::Array(ref array_ty) => {
let wit_to_buf_impl = quote! {
<#array_ty as WitnessVarLengthEncodable<F>>::encode_witness_to_buffer(&witness.#field_ident, dst);
};
witness_to_buffer_impls.extend(wit_to_buf_impl);
let wit_length_impl = quote! {
total_len += <#array_ty as WitnessVarLengthEncodable<F>>::witness_encoding_length(&witness.#field_ident);
};
witness_length_impls.extend(wit_length_impl);
}
Type::Path(ref path_ty) => {
let wit_to_buf_impl = quote! {
<#path_ty as WitnessVarLengthEncodable<F>>::encode_witness_to_buffer(&witness.#field_ident, dst);
};
witness_to_buffer_impls.extend(wit_to_buf_impl);
let wit_length_impl = quote! {
total_len += <#path_ty as WitnessVarLengthEncodable<F>>::witness_encoding_length(&witness.#field_ident);
};
witness_length_impls.extend(wit_length_impl);
}
_ => abort_call_site!("only array and path types are allowed"),
};
}
}
_ => abort_call_site!("only named fields are allowed!"),
},
_ => abort_call_site!("only struct types are allowed!"),
}

let comma = Comma(Span::call_site());

let field_generic_param = syn::parse_str::<GenericParam>(&"F: SmallField").unwrap();
let has_engine_param = has_proper_small_field_parameter(&generics.params, &field_generic_param);
if has_engine_param == false {
panic!("Expected to have `F: SmallField` somewhere in bounds");
}

let type_params_of_allocated_struct = get_type_params_from_generics(&generics, &comma);

let expanded = quote! {
impl #generics WitnessVarLengthEncodable<F> for #ident<#type_params_of_allocated_struct> #bound {
fn witness_encoding_length(witness: &Self::Witness) -> usize {
let mut total_len = 0;
#witness_length_impls

total_len
}
fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec<F>) {
#witness_to_buffer_impls
}
}
};

proc_macro::TokenStream::from(expanded)
}
5 changes: 4 additions & 1 deletion src/gadgets/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ impl<F: SmallField> Boolean<F> {
}
}

use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable};

impl<F: SmallField> CircuitVarLengthEncodable<F> for Boolean<F> {
#[inline(always)]
Expand All @@ -712,6 +712,9 @@ impl<F: SmallField> CircuitVarLengthEncodable<F> for Boolean<F> {
fn encode_to_buffer<CS: ConstraintSystem<F>>(&self, _cs: &mut CS, dst: &mut Vec<Variable>) {
dst.push(self.variable);
}
}

impl<F: SmallField> WitnessVarLengthEncodable<F> for Boolean<F> {
fn witness_encoding_length(_witness: &Self::Witness) -> usize {
1
}
Expand Down
5 changes: 4 additions & 1 deletion src/gadgets/num/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ pub fn dot_product_using_dot_product_gate<F: SmallField, CS: ConstraintSystem<F>
result
}

use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable};

impl<F: SmallField> CircuitVarLengthEncodable<F> for Num<F> {
#[inline(always)]
Expand All @@ -1224,6 +1224,9 @@ impl<F: SmallField> CircuitVarLengthEncodable<F> for Num<F> {
fn encode_to_buffer<CS: ConstraintSystem<F>>(&self, _cs: &mut CS, dst: &mut Vec<Variable>) {
dst.push(self.get_variable());
}
}

impl<F: SmallField> WitnessVarLengthEncodable<F> for Num<F> {
#[inline(always)]
fn witness_encoding_length(_witness: &Self::Witness) -> usize {
1
Expand Down
4 changes: 2 additions & 2 deletions src/gadgets/queue/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -582,14 +582,14 @@ pub fn simulate_new_tail<
use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::serde_utils::BigArraySerde;

#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitnessHookable)]
#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitVarLengthEncodable, WitnessHookable)]
#[derivative(Clone, Copy, Debug)]
pub struct QueueState<F: SmallField, const N: usize> {
pub head: [Num<F>; N],
pub tail: QueueTailState<F, N>,
}

#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitnessHookable)]
#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitVarLengthEncodable, WitnessHookable)]
#[derivative(Clone, Copy, Debug)]
pub struct QueueTailState<F: SmallField, const N: usize> {
pub tail: [Num<F>; N],
Expand Down
2 changes: 1 addition & 1 deletion src/gadgets/queue/queue_optimizer/sponge_optimizer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::*;

#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitnessHookable)]
#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitVarLengthEncodable, WitnessHookable)]
#[derivative(Clone, Debug, Copy)]
pub struct SpongeRoundRequest<F: SmallField, const SWIDTH: usize> {
pub initial_state: [Num<F>; SWIDTH],
Expand Down
7 changes: 6 additions & 1 deletion src/gadgets/recursion/allocated_vk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ where
}
}

use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable};

impl<F: SmallField, H: RecursiveTreeHasher<F, Num<F>>> CircuitVarLengthEncodable<F>
for AllocatedVerificationKey<F, H>
Expand All @@ -123,7 +123,11 @@ impl<F: SmallField, H: RecursiveTreeHasher<F, Num<F>>> CircuitVarLengthEncodable
el.encode_to_buffer(cs, dst);
}
}
}

impl<F: SmallField, H: RecursiveTreeHasher<F, Num<F>>> WitnessVarLengthEncodable<F>
for AllocatedVerificationKey<F, H>
{
fn witness_encoding_length(witness: &Self::Witness) -> usize {
let cap_size = witness.setup_merkle_tree_cap.len();
assert!(cap_size > 0);
Expand All @@ -138,3 +142,4 @@ impl<F: SmallField, H: RecursiveTreeHasher<F, Num<F>>> CircuitVarLengthEncodable
}
}
}

5 changes: 4 additions & 1 deletion src/gadgets/recursion/recursive_tree_hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ pub trait CircuitTreeHasher<F: SmallField, B: Sized + CSAllocatable<F>>:
+ Eq
+ std::fmt::Debug
+ CSAllocatable<F>
+ CircuitVarLengthEncodable<F>;
+ CircuitVarLengthEncodable<F>
+ WitnessVarLengthEncodable<F>;

fn new<CS: ConstraintSystem<F>>(cs: &mut CS) -> Self;
fn placeholder_output<CS: ConstraintSystem<F>>(cs: &mut CS) -> Self::CircuitOutput;
Expand Down Expand Up @@ -87,6 +88,8 @@ pub trait RecursiveTreeHasher<F: SmallField, B: Sized + CSAllocatable<F>>:

use crate::gadgets::round_function::CircuitSimpleAlgebraicSponge;

use super::traits::encodable::WitnessVarLengthEncodable;

impl<
F: SmallField,
const AW: usize,
Expand Down
16 changes: 16 additions & 0 deletions src/gadgets/traits/encodable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ pub trait CircuitEncodable<F: SmallField, const N: usize>:
'static + Send + Sync + CSAllocatable<F>
{
fn encode<CS: ConstraintSystem<F>>(&self, cs: &mut CS) -> [Variable; N];
}

pub trait WitnessEncodable<F: SmallField, const N: usize>:
'static + Send + Sync + CSAllocatable<F>
{
fn encode_witness(witness: &Self::Witness, dst: &mut Vec<F>);
}

Expand All @@ -20,7 +25,11 @@ pub trait CircuitVarLengthEncodable<F: SmallField>:
{
fn encoding_length(&self) -> usize;
fn encode_to_buffer<CS: ConstraintSystem<F>>(&self, cs: &mut CS, dst: &mut Vec<Variable>);
}

pub trait WitnessVarLengthEncodable<F: SmallField>:
'static + Send + Sync + CSAllocatable<F>
{
fn witness_encoding_length(witness: &Self::Witness) -> usize;
fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec<F>);
}
Expand Down Expand Up @@ -52,7 +61,11 @@ impl<F: SmallField, const N: usize, T: CircuitVarLengthEncodable<F>> CircuitVarL
el.encode_to_buffer(cs, dst);
}
}
}

impl<F: SmallField, const N: usize, T: WitnessVarLengthEncodable<F>> WitnessVarLengthEncodable<F>
for [T; N]
{
fn witness_encoding_length(witness: &Self::Witness) -> usize {
debug_assert!(N > 0);

Expand All @@ -74,6 +87,9 @@ impl<F: SmallField> CircuitVarLengthEncodable<F> for () {
fn encode_to_buffer<CS: ConstraintSystem<F>>(&self, _cs: &mut CS, _dst: &mut Vec<Variable>) {
// do nothing
}
}

impl<F: SmallField> WitnessVarLengthEncodable<F> for () {
#[inline(always)]
fn witness_encoding_length(_witness: &Self::Witness) -> usize {
0
Expand Down
5 changes: 4 additions & 1 deletion src/gadgets/u16/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ impl<F: SmallField> UInt16<F> {
}
}

use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable};

impl<F: SmallField> CircuitVarLengthEncodable<F> for UInt16<F> {
#[inline(always)]
Expand All @@ -569,6 +569,9 @@ impl<F: SmallField> CircuitVarLengthEncodable<F> for UInt16<F> {
fn encode_to_buffer<CS: ConstraintSystem<F>>(&self, _cs: &mut CS, dst: &mut Vec<Variable>) {
dst.push(self.variable);
}
}

impl<F: SmallField> WitnessVarLengthEncodable<F> for UInt16<F> {
#[inline(always)]
fn witness_encoding_length(_witness: &Self::Witness) -> usize {
1
Expand Down
5 changes: 4 additions & 1 deletion src/gadgets/u160/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ use crate::gadgets::traits::selectable::MultiSelectable;
// so we degrade to default impl via normal select
impl<F: SmallField> MultiSelectable<F> for UInt160<F> {}

use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable};

impl<F: SmallField> CircuitVarLengthEncodable<F> for UInt160<F> {
#[inline(always)]
Expand All @@ -255,6 +255,9 @@ impl<F: SmallField> CircuitVarLengthEncodable<F> for UInt160<F> {
fn encode_to_buffer<CS: ConstraintSystem<F>>(&self, cs: &mut CS, dst: &mut Vec<Variable>) {
CircuitVarLengthEncodable::<F>::encode_to_buffer(&self.inner, cs, dst);
}
}

impl<F: SmallField> WitnessVarLengthEncodable<F> for UInt160<F> {
#[inline(always)]
fn witness_encoding_length(_witness: &Self::Witness) -> usize {
5
Expand Down
5 changes: 4 additions & 1 deletion src/gadgets/u256/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ use crate::gadgets::traits::selectable::MultiSelectable;
// so we degrade to default impl via normal select
impl<F: SmallField> MultiSelectable<F> for UInt256<F> {}

use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable};

impl<F: SmallField> CircuitVarLengthEncodable<F> for UInt256<F> {
#[inline(always)]
Expand All @@ -413,6 +413,9 @@ impl<F: SmallField> CircuitVarLengthEncodable<F> for UInt256<F> {
fn encode_to_buffer<CS: ConstraintSystem<F>>(&self, cs: &mut CS, dst: &mut Vec<Variable>) {
CircuitVarLengthEncodable::<F>::encode_to_buffer(&self.inner, cs, dst);
}
}

impl<F: SmallField> WitnessVarLengthEncodable<F> for UInt256<F> {
#[inline(always)]
fn witness_encoding_length(_witness: &Self::Witness) -> usize {
8
Expand Down
Loading

0 comments on commit 0dd8673

Please sign in to comment.