Skip to content
This repository has been archived by the owner on Aug 15, 2024. It is now read-only.

Commit

Permalink
Merge branch 'main' into crates.io-0.2.0-branch
Browse files Browse the repository at this point in the history
  • Loading branch information
popzxc authored Jul 31, 2024
2 parents 57b7ae5 + 2661f5a commit 544e350
Show file tree
Hide file tree
Showing 16 changed files with 282 additions and 14 deletions.
7 changes: 7 additions & 0 deletions cs_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod selectable;
pub(crate) mod utils;
mod var_length_encodable;
mod witness_hook;
mod witness_var_length_encodable;

#[proc_macro_derive(CSSelectable, attributes(CSSelectableBound))]
#[proc_macro_error::proc_macro_error]
Expand All @@ -31,3 +32,9 @@ pub fn derive_witness_hook(input: TokenStream) -> TokenStream {
pub fn derive_var_length_encodable(input: TokenStream) -> TokenStream {
self::var_length_encodable::derive_var_length_encodable(input)
}

#[proc_macro_derive(WitVarLengthEncodable, attributes(WitnessVarLengthEncodableBound))]
#[proc_macro_error::proc_macro_error]
pub fn derive_witness_length_encodable(input: TokenStream) -> TokenStream {
self::witness_var_length_encodable::derive_witness_var_length_encodable(input)
}
2 changes: 1 addition & 1 deletion cs_derive/src/var_length_encodable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub(crate) fn derive_var_length_encodable(
};
field_impls.extend(field_impl);
}
Type::Path(_) => {
Type::Path(ref _path_ty) => {
let field_impl = quote! {
total_len += CircuitVarLengthEncodable::<F>::encoding_length(&self.#field_ident);
};
Expand Down
95 changes: 95 additions & 0 deletions cs_derive/src/witness_var_length_encodable/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use proc_macro2::{Span, TokenStream};
use proc_macro_error::abort_call_site;
use quote::quote;
use syn::{parse_macro_input, token::Comma, DeriveInput, GenericParam, Type, WhereClause};

use crate::utils::*;

const BOUND_ATTR_NAME: &'static str = "WitnessVarLengthEncodableBound";

pub(crate) fn derive_witness_var_length_encodable(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let derived_input = parse_macro_input!(input as DeriveInput);
let DeriveInput {
ident,
generics,
data,
attrs,
..
} = derived_input.clone();

let mut witness_to_buffer_impls = TokenStream::new();
let mut witness_length_impls = TokenStream::new();

let extra_bound = if let Some(bound) = fetch_attr_from_list(BOUND_ATTR_NAME, &attrs) {
let bound = syn::parse_str::<WhereClause>(&bound).expect("must parse bound as WhereClause");

Some(bound)
} else {
None
};

let bound = merge_where_clauses(generics.where_clause.clone(), extra_bound);

match data {
syn::Data::Struct(ref struct_data) => match struct_data.fields {
syn::Fields::Named(ref named_fields) => {
for field in named_fields.named.iter() {
let field_ident = field.ident.clone().expect("should have a field elem ident");
match field.ty {
Type::Array(ref array_ty) => {
let wit_to_buf_impl = quote! {
<#array_ty as WitnessVarLengthEncodable<F>>::encode_witness_to_buffer(&witness.#field_ident, dst);
};
witness_to_buffer_impls.extend(wit_to_buf_impl);
let wit_length_impl = quote! {
total_len += <#array_ty as WitnessVarLengthEncodable<F>>::witness_encoding_length(&witness.#field_ident);
};
witness_length_impls.extend(wit_length_impl);
}
Type::Path(ref path_ty) => {
let wit_to_buf_impl = quote! {
<#path_ty as WitnessVarLengthEncodable<F>>::encode_witness_to_buffer(&witness.#field_ident, dst);
};
witness_to_buffer_impls.extend(wit_to_buf_impl);
let wit_length_impl = quote! {
total_len += <#path_ty as WitnessVarLengthEncodable<F>>::witness_encoding_length(&witness.#field_ident);
};
witness_length_impls.extend(wit_length_impl);
}
_ => abort_call_site!("only array and path types are allowed"),
};
}
}
_ => abort_call_site!("only named fields are allowed!"),
},
_ => abort_call_site!("only struct types are allowed!"),
}

let comma = Comma(Span::call_site());

let field_generic_param = syn::parse_str::<GenericParam>(&"F: SmallField").unwrap();
let has_engine_param = has_proper_small_field_parameter(&generics.params, &field_generic_param);
if has_engine_param == false {
panic!("Expected to have `F: SmallField` somewhere in bounds");
}

let type_params_of_allocated_struct = get_type_params_from_generics(&generics, &comma);

let expanded = quote! {
impl #generics WitnessVarLengthEncodable<F> for #ident<#type_params_of_allocated_struct> #bound {
fn witness_encoding_length(witness: &Self::Witness) -> usize {
let mut total_len = 0;
#witness_length_impls

total_len
}
fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec<F>) {
#witness_to_buffer_impls
}
}
};

proc_macro::TokenStream::from(expanded)
}
12 changes: 11 additions & 1 deletion src/gadgets/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ impl<F: SmallField> Boolean<F> {
}
}

use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable};

impl<F: SmallField> CircuitVarLengthEncodable<F> for Boolean<F> {
#[inline(always)]
Expand All @@ -713,3 +713,13 @@ impl<F: SmallField> CircuitVarLengthEncodable<F> for Boolean<F> {
dst.push(self.variable);
}
}

impl<F: SmallField> WitnessVarLengthEncodable<F> for Boolean<F> {
fn witness_encoding_length(_witness: &Self::Witness) -> usize {
1
}
fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec<F>) {
let val = F::from_raw_u64_unchecked(*witness as u64);
dst.push(val);
}
}
12 changes: 11 additions & 1 deletion src/gadgets/num/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ pub fn dot_product_using_dot_product_gate<F: SmallField, CS: ConstraintSystem<F>
result
}

use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable};

impl<F: SmallField> CircuitVarLengthEncodable<F> for Num<F> {
#[inline(always)]
Expand All @@ -1226,6 +1226,16 @@ impl<F: SmallField> CircuitVarLengthEncodable<F> for Num<F> {
}
}

impl<F: SmallField> WitnessVarLengthEncodable<F> for Num<F> {
#[inline(always)]
fn witness_encoding_length(_witness: &Self::Witness) -> usize {
1
}
fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec<F>) {
dst.push(*witness);
}
}

use crate::gadgets::traits::allocatable::CSAllocatableExt;

impl<F: SmallField> CSAllocatableExt<F> for Num<F> {
Expand Down
20 changes: 17 additions & 3 deletions src/gadgets/queue/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use super::boolean::Boolean;
use super::num::Num;
use super::traits::allocatable::*;
use super::u32::UInt32;
use super::{traits::encodable::CircuitEncodable, *};
use super::{traits::encodable::CircuitEncodable, traits::encodable::WitnessVarLengthEncodable, *};
use crate::algebraic_props::round_function::AbsorptionModeOverwrite;
use crate::algebraic_props::round_function::AlgebraicRoundFunction;
use crate::config::CSConfig;
Expand Down Expand Up @@ -582,14 +582,28 @@ pub fn simulate_new_tail<
use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::serde_utils::BigArraySerde;

#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitnessHookable)]
#[derive(
Derivative,
CSAllocatable,
CSSelectable,
CSVarLengthEncodable,
WitVarLengthEncodable,
WitnessHookable,
)]
#[derivative(Clone, Copy, Debug)]
pub struct QueueState<F: SmallField, const N: usize> {
pub head: [Num<F>; N],
pub tail: QueueTailState<F, N>,
}

#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitnessHookable)]
#[derive(
Derivative,
CSAllocatable,
CSSelectable,
CSVarLengthEncodable,
WitVarLengthEncodable,
WitnessHookable,
)]
#[derivative(Clone, Copy, Debug)]
pub struct QueueTailState<F: SmallField, const N: usize> {
pub tail: [Num<F>; N],
Expand Down
9 changes: 8 additions & 1 deletion src/gadgets/queue/queue_optimizer/sponge_optimizer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
use super::*;

#[derive(Derivative, CSAllocatable, CSSelectable, CSVarLengthEncodable, WitnessHookable)]
#[derive(
Derivative,
CSAllocatable,
CSSelectable,
CSVarLengthEncodable,
WitVarLengthEncodable,
WitnessHookable,
)]
#[derivative(Clone, Debug, Copy)]
pub struct SpongeRoundRequest<F: SmallField, const SWIDTH: usize> {
pub initial_state: [Num<F>; SWIDTH],
Expand Down
20 changes: 19 additions & 1 deletion src/gadgets/recursion/allocated_vk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ where
}
}

use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable};

impl<F: SmallField, H: RecursiveTreeHasher<F, Num<F>>> CircuitVarLengthEncodable<F>
for AllocatedVerificationKey<F, H>
Expand All @@ -124,3 +124,21 @@ impl<F: SmallField, H: RecursiveTreeHasher<F, Num<F>>> CircuitVarLengthEncodable
}
}
}

impl<F: SmallField, H: RecursiveTreeHasher<F, Num<F>>> WitnessVarLengthEncodable<F>
for AllocatedVerificationKey<F, H>
{
fn witness_encoding_length(witness: &Self::Witness) -> usize {
let cap_size = witness.setup_merkle_tree_cap.len();
assert!(cap_size > 0);
let el_size = H::CircuitOutput::witness_encoding_length(&witness.setup_merkle_tree_cap[0]);

el_size * cap_size
}

fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec<F>) {
for el in witness.setup_merkle_tree_cap.iter() {
H::CircuitOutput::encode_witness_to_buffer(el, dst)
}
}
}
5 changes: 4 additions & 1 deletion src/gadgets/recursion/recursive_tree_hasher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ pub trait CircuitTreeHasher<F: SmallField, B: Sized + CSAllocatable<F>>:
+ Eq
+ std::fmt::Debug
+ CSAllocatable<F>
+ CircuitVarLengthEncodable<F>;
+ CircuitVarLengthEncodable<F>
+ WitnessVarLengthEncodable<F>;

fn new<CS: ConstraintSystem<F>>(cs: &mut CS) -> Self;
fn placeholder_output<CS: ConstraintSystem<F>>(cs: &mut CS) -> Self::CircuitOutput;
Expand Down Expand Up @@ -87,6 +88,8 @@ pub trait RecursiveTreeHasher<F: SmallField, B: Sized + CSAllocatable<F>>:

use crate::gadgets::round_function::CircuitSimpleAlgebraicSponge;

use super::traits::encodable::WitnessVarLengthEncodable;

impl<
F: SmallField,
const AW: usize,
Expand Down
39 changes: 39 additions & 0 deletions src/gadgets/traits/encodable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ pub trait CircuitEncodable<F: SmallField, const N: usize>:
fn encode<CS: ConstraintSystem<F>>(&self, cs: &mut CS) -> [Variable; N];
}

pub trait WitnessEncodable<F: SmallField, const N: usize>:
'static + Send + Sync + CSAllocatable<F>
{
fn encode_witness(witness: &Self::Witness, dst: &mut Vec<F>);
}

pub trait CircuitEncodableExt<F: SmallField, const N: usize>:
CircuitEncodable<F, N> + CSAllocatableExt<F>
{
Expand All @@ -21,6 +27,13 @@ pub trait CircuitVarLengthEncodable<F: SmallField>:
fn encode_to_buffer<CS: ConstraintSystem<F>>(&self, cs: &mut CS, dst: &mut Vec<Variable>);
}

pub trait WitnessVarLengthEncodable<F: SmallField>:
'static + Send + Sync + CSAllocatable<F>
{
fn witness_encoding_length(witness: &Self::Witness) -> usize;
fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec<F>);
}

// unfortunately default implementation is impossible as compiler can not have constraint "for all N"

// impl<F: SmallField, const N: usize, T: CircuitEncodable<F, N>> CircuitVarLengthEncodable<F> for T {
Expand Down Expand Up @@ -50,6 +63,22 @@ impl<F: SmallField, const N: usize, T: CircuitVarLengthEncodable<F>> CircuitVarL
}
}

impl<F: SmallField, const N: usize, T: WitnessVarLengthEncodable<F>> WitnessVarLengthEncodable<F>
for [T; N]
{
fn witness_encoding_length(witness: &Self::Witness) -> usize {
debug_assert!(N > 0);

N * T::witness_encoding_length(&witness[0])
}

fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec<F>) {
for el in witness.iter() {
T::encode_witness_to_buffer(el, dst);
}
}
}

impl<F: SmallField> CircuitVarLengthEncodable<F> for () {
#[inline(always)]
fn encoding_length(&self) -> usize {
Expand All @@ -59,3 +88,13 @@ impl<F: SmallField> CircuitVarLengthEncodable<F> for () {
// do nothing
}
}

impl<F: SmallField> WitnessVarLengthEncodable<F> for () {
#[inline(always)]
fn witness_encoding_length(_witness: &Self::Witness) -> usize {
0
}
fn encode_witness_to_buffer(_witness: &Self::Witness, _dst: &mut Vec<F>) {
// do nothing
}
}
12 changes: 11 additions & 1 deletion src/gadgets/u16/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ impl<F: SmallField> UInt16<F> {
}
}

use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable};

impl<F: SmallField> CircuitVarLengthEncodable<F> for UInt16<F> {
#[inline(always)]
Expand All @@ -570,3 +570,13 @@ impl<F: SmallField> CircuitVarLengthEncodable<F> for UInt16<F> {
dst.push(self.variable);
}
}

impl<F: SmallField> WitnessVarLengthEncodable<F> for UInt16<F> {
#[inline(always)]
fn witness_encoding_length(_witness: &Self::Witness) -> usize {
1
}
fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec<F>) {
dst.push(F::from_u64_with_reduction(*witness as u64));
}
}
13 changes: 12 additions & 1 deletion src/gadgets/u160/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ use crate::gadgets::traits::selectable::MultiSelectable;
// so we degrade to default impl via normal select
impl<F: SmallField> MultiSelectable<F> for UInt160<F> {}

use crate::gadgets::traits::encodable::CircuitVarLengthEncodable;
use crate::gadgets::traits::encodable::{CircuitVarLengthEncodable, WitnessVarLengthEncodable};

impl<F: SmallField> CircuitVarLengthEncodable<F> for UInt160<F> {
#[inline(always)]
Expand All @@ -257,6 +257,17 @@ impl<F: SmallField> CircuitVarLengthEncodable<F> for UInt160<F> {
}
}

impl<F: SmallField> WitnessVarLengthEncodable<F> for UInt160<F> {
#[inline(always)]
fn witness_encoding_length(_witness: &Self::Witness) -> usize {
5
}
fn encode_witness_to_buffer(witness: &Self::Witness, dst: &mut Vec<F>) {
let chunks = decompose_address_as_u32x5(*witness);
chunks.map(|el| UInt32::encode_witness_to_buffer(&el, dst));
}
}

use crate::gadgets::traits::allocatable::CSPlaceholder;

impl<F: SmallField> CSPlaceholder<F> for UInt160<F> {
Expand Down
Loading

0 comments on commit 544e350

Please sign in to comment.