Skip to content

Commit

Permalink
Refactoring starks and expressions in C++ class
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerTaule committed Sep 6, 2024
1 parent 18b7b46 commit 23329ee
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 134 deletions.
7 changes: 2 additions & 5 deletions common/src/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{os::raw::c_void, path::Path};

use log::trace;

use proofman_starks_lib_c::{const_pols_new_c, expressions_bin_new_c, expressions_ctx_new_c, setup_ctx_new_c, stark_info_new_c};
use proofman_starks_lib_c::{const_pols_new_c, expressions_bin_new_c, setup_ctx_new_c, stark_info_new_c};

use crate::GlobalInfo;

Expand All @@ -13,7 +13,6 @@ pub struct Setup {
pub air_id: usize,
pub p_setup: *mut c_void,
pub p_stark_info: *mut c_void,
pub p_expressions: *mut c_void,
}

impl Setup {
Expand Down Expand Up @@ -46,8 +45,6 @@ impl Setup {

let p_setup = setup_ctx_new_c(p_stark_info, p_expressions_bin, p_const_pols);

let p_expressions = expressions_ctx_new_c(p_setup);

Self { air_id, air_group_id, p_setup, p_expressions, p_stark_info }
Self { air_id, air_group_id, p_setup, p_stark_info }
}
}
18 changes: 9 additions & 9 deletions hints/src/hints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ impl HintCol {
}
}

pub fn get_hint_ids_by_name(p_expressions: *mut c_void, name: &str) -> Vec<u64> {
let raw_ptr = get_hint_ids_by_name_c(p_expressions, name);
pub fn get_hint_ids_by_name(p_setup: *mut c_void, name: &str) -> Vec<u64> {
let raw_ptr = get_hint_ids_by_name_c(p_setup, name);

let hint_ids_result = unsafe { Box::from_raw(raw_ptr as *mut HintIdsResult) };

Expand All @@ -246,7 +246,7 @@ pub fn get_hint_field<F: Clone + Copy>(

let setup = setup_ctx.get_setup(air_instance_ctx.air_group_id, air_instance_ctx.air_id).expect("REASON");

let raw_ptr = get_hint_field_c(setup.p_expressions, params, hint_id as u64, hint_field_name, dest, inverse, print_expression);
let raw_ptr = get_hint_field_c(setup.p_setup, params, hint_id as u64, hint_field_name, dest, inverse, print_expression);

let hint_field = unsafe { Box::from_raw(raw_ptr as *mut HintFieldInfo<F>) };

Expand All @@ -265,7 +265,7 @@ pub fn get_hint_field_constant<F: Clone + Copy>(

let setup = setup_ctx.get_setup(air_group_id, air_id).expect("REASON");

let raw_ptr = get_hint_field_c(setup.p_expressions, std::ptr::null_mut(), hint_id as u64, hint_field_name, dest, false, print_expression);
let raw_ptr = get_hint_field_c(setup.p_setup, std::ptr::null_mut(), hint_id as u64, hint_field_name, dest, false, print_expression);

let hint_field = unsafe { Box::from_raw(raw_ptr as *mut HintFieldInfo<F>) };

Expand All @@ -290,7 +290,7 @@ pub fn set_hint_field<F: Copy + core::fmt::Debug>(
_ => panic!("Only column and column extended are accepted"),
};

let id = set_hint_field_c(setup.p_expressions, params, values_ptr, hint_id, hint_field_name);
let id = set_hint_field_c(setup.p_setup, params, values_ptr, hint_id, hint_field_name);

air_instance_ctx.set_commit_calculated(id as usize);
}
Expand Down Expand Up @@ -321,7 +321,7 @@ pub fn set_hint_field_val<F: Clone + Copy + std::fmt::Debug>(

let values_ptr = value_array.as_mut_ptr() as *mut c_void;

let id = set_hint_field_c(setup.p_expressions, params, values_ptr, hint_id, hint_field_name);
let id = set_hint_field_c(setup.p_setup, params, values_ptr, hint_id, hint_field_name);

air_instance_ctx.set_subproofvalue_calculated(id as usize);
}
Expand All @@ -337,10 +337,10 @@ pub fn print_expression<F: Clone + Copy + Debug>(

match expr {
HintFieldValue::Column(vec) => {
print_expression_c(setup.p_expressions, vec.as_ptr() as *mut c_void, 1, first_print_value, last_print_value);
print_expression_c(setup.p_setup, vec.as_ptr() as *mut c_void, 1, first_print_value, last_print_value);
}
HintFieldValue::ColumnExtended(vec) => {
print_expression_c(setup.p_expressions, vec.as_ptr() as *mut c_void, 3, first_print_value, last_print_value);
print_expression_c(setup.p_setup, vec.as_ptr() as *mut c_void, 3, first_print_value, last_print_value);
}
HintFieldValue::Field(val) => {
println!("Field value: {:?}", val);
Expand All @@ -367,7 +367,7 @@ pub fn print_by_name<F: Clone + Copy>(
let lengths_ptr = lengths.as_ref().map(|lengths| lengths.clone().as_mut_ptr()).unwrap_or(std::ptr::null_mut());

// TODO: CHECK WHAT IS WRONG WITH RETURN VALUES
let _raw_ptr = print_by_name_c(setup.p_expressions, params, name, lengths_ptr, first_print_value, last_print_value, false);
let _raw_ptr = print_by_name_c(setup.p_setup, params, name, lengths_ptr, first_print_value, last_print_value, false);

// if return_values {
// let field = unsafe { Box::from_raw(raw_ptr as *mut HintFieldInfo<F>) };
Expand Down
17 changes: 8 additions & 9 deletions provers/stark/src/stark_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ pub struct StarkProver<T: Field> {
air_group_id: usize,
config: StarkProverSettings,
p_setup: *mut c_void,
p_expressions: *mut c_void,
pub p_stark: *mut c_void,
p_stark_info: *mut c_void,
stark_info: StarkInfo,
Expand Down Expand Up @@ -67,10 +66,9 @@ impl<T: Field> StarkProver<T> {
let setup = sctx.get_setup(air_group_id, air_id).expect("REASON");

let p_setup = setup.p_setup;
let p_expressions = setup.p_expressions;
let p_stark_info = setup.p_stark_info;

let p_stark = starks_new_c(p_setup, p_expressions);
let p_stark = starks_new_c(p_setup);

let stark_info_path = base_filename_path.clone() + ".starkinfo.json";
let stark_info_json = std::fs::read_to_string(&stark_info_path)
Expand Down Expand Up @@ -104,7 +102,6 @@ impl<T: Field> StarkProver<T> {
air_group_id,
config,
p_setup,
p_expressions,
p_stark_info,
p_stark,
p_proof: None,
Expand Down Expand Up @@ -179,7 +176,7 @@ impl<F: Field> Prover<F> for StarkProver<F> {
fn verify_constraints(&self, proof_ctx: &mut ProofCtx<F>) -> Vec<ConstraintInfo> {
let air_instance_ctx = &mut proof_ctx.air_instances.write().unwrap()[self.prover_idx];

let raw_ptr = verify_constraints_c(self.p_expressions, air_instance_ctx.params.unwrap());
let raw_ptr = verify_constraints_c(self.p_setup, air_instance_ctx.params.unwrap());

let constraints_result = unsafe { Box::from_raw(raw_ptr as *mut ConstraintsResults) };

Expand All @@ -204,7 +201,7 @@ impl<F: Field> Prover<F> for StarkProver<F> {
panic!("Intermediate polynomials for stage {} cannot be calculated: Witness column {} is not calculated", stage_id, cm_pol.name);
}
}
calculate_impols_expressions_c(self.p_expressions, air_instance_ctx.params.unwrap(), stage_id as u64);
calculate_impols_expressions_c(self.p_stark, air_instance_ctx.params.unwrap(), stage_id as u64);
for i in 0..n_commits {
let cm_pol = self.stark_info.cm_pols_map.as_ref().expect("REASON").get(i).unwrap();
if cm_pol.stage == stage_id as u64 && cm_pol.im_pol {
Expand All @@ -216,7 +213,7 @@ impl<F: Field> Prover<F> for StarkProver<F> {
fri_proof_set_subproof_values_c(p_proof, air_instance_ctx.params.unwrap());
}
} else {
calculate_quotient_polynomial_c(self.p_expressions, air_instance_ctx.params.unwrap());
calculate_quotient_polynomial_c(self.p_stark, air_instance_ctx.params.unwrap());
for i in 0..n_commits {
let cm_pol: &crate::stark_info::PolMap = self.stark_info.cm_pols_map.as_ref().expect("REASON").get(i).unwrap();
if cm_pol.stage == (proof_ctx.pilout.num_stages() + 1) as u64 {
Expand Down Expand Up @@ -394,7 +391,9 @@ impl<F: Field> StarkProver<F> {

debug!("{}: ··· Computing FRI Polynomial", Self::MY_NAME);

compute_fri_pol_c(p_stark, self.stark_info.n_stages as u64 + 2, air_instance_ctx.params.unwrap());
prepare_fri_polynomial_c(p_stark, air_instance_ctx.params.unwrap());

calculate_fri_polynomial_c(p_stark, air_instance_ctx.params.unwrap());
}

fn compute_fri_folding(&mut self, opening_id: u32, proof_ctx: &mut ProofCtx<F>, transcript: &FFITranscript) {
Expand Down Expand Up @@ -425,7 +424,7 @@ impl<F: Field> StarkProver<F> {
} else {
let hash: Vec<F> = vec![F::zero(); self.n_field_elements];
let n_hash = (1 << (steps[n_steps - 1].n_bits)) * Self::FIELD_EXTENSION as u64;
let fri_pol = get_fri_pol_c(self.p_expressions, air_instance_ctx.params.unwrap());
let fri_pol = get_fri_pol_c(self.p_setup, air_instance_ctx.params.unwrap());
calculate_hash_c(p_stark, hash.as_ptr() as *mut c_void, fri_pol, n_hash);
transcript.add_elements(hash.as_ptr() as *mut c_void, self.n_field_elements);
}
Expand Down
101 changes: 49 additions & 52 deletions provers/starks-lib-c/bindings_starks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ extern "C" {
p_const_pols: *mut ::std::os::raw::c_void,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
#[link_name = "\u{1}_Z20get_hint_ids_by_namePvPc"]
pub fn get_hint_ids_by_name(
pSetupCtx: *mut ::std::os::raw::c_void,
hintName: *mut ::std::os::raw::c_char,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
#[link_name = "\u{1}_Z14setup_ctx_freePv"]
pub fn setup_ctx_free(pSetupCtx: *mut ::std::os::raw::c_void);
Expand Down Expand Up @@ -100,36 +107,19 @@ extern "C" {
pub fn expressions_bin_free(pExpressionsBin: *mut ::std::os::raw::c_void);
}
extern "C" {
#[link_name = "\u{1}_Z19expressions_ctx_newPv"]
pub fn expressions_ctx_new(
pSetupCtx: *mut ::std::os::raw::c_void,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
#[link_name = "\u{1}_Z18verify_constraintsPvS_"]
pub fn verify_constraints(
pExpressionsCtx: *mut ::std::os::raw::c_void,
pParams: *mut ::std::os::raw::c_void,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
#[link_name = "\u{1}_Z11get_fri_polPvS_"]
pub fn get_fri_pol(
pExpressionsCtx: *mut ::std::os::raw::c_void,
pParams: *mut ::std::os::raw::c_void,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
#[link_name = "\u{1}_Z20get_hint_ids_by_namePvPc"]
pub fn get_hint_ids_by_name(
pExpressionsCtx: *mut ::std::os::raw::c_void,
hintName: *mut ::std::os::raw::c_char,
#[link_name = "\u{1}_Z11init_paramsPvS_S_S_S_"]
pub fn init_params(
ptr: *mut ::std::os::raw::c_void,
public_inputs: *mut ::std::os::raw::c_void,
challenges: *mut ::std::os::raw::c_void,
evals: *mut ::std::os::raw::c_void,
subproofValues: *mut ::std::os::raw::c_void,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
#[link_name = "\u{1}_Z14get_hint_fieldPvS_mPcbbb"]
pub fn get_hint_field(
pExpressionsCtx: *mut ::std::os::raw::c_void,
pSetupCtx: *mut ::std::os::raw::c_void,
pParams: *mut ::std::os::raw::c_void,
hintId: u64,
hintFieldName: *mut ::std::os::raw::c_char,
Expand All @@ -141,34 +131,35 @@ extern "C" {
extern "C" {
#[link_name = "\u{1}_Z14set_hint_fieldPvS_S_mPc"]
pub fn set_hint_field(
pExpressionsCtx: *mut ::std::os::raw::c_void,
pSetupCtx: *mut ::std::os::raw::c_void,
pParams: *mut ::std::os::raw::c_void,
values: *mut ::std::os::raw::c_void,
hintId: u64,
hintFieldName: *mut ::std::os::raw::c_char,
) -> u64;
}
extern "C" {
#[link_name = "\u{1}_Z20expressions_ctx_freePv"]
pub fn expressions_ctx_free(pExpressionsCtx: *mut ::std::os::raw::c_void);
}
extern "C" {
#[link_name = "\u{1}_Z11init_paramsPvS_S_S_S_"]
pub fn init_params(
ptr: *mut ::std::os::raw::c_void,
public_inputs: *mut ::std::os::raw::c_void,
challenges: *mut ::std::os::raw::c_void,
evals: *mut ::std::os::raw::c_void,
subproofValues: *mut ::std::os::raw::c_void,
#[link_name = "\u{1}_Z11get_fri_polPvS_"]
pub fn get_fri_pol(
pSetupCtx: *mut ::std::os::raw::c_void,
pParams: *mut ::std::os::raw::c_void,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
#[link_name = "\u{1}_Z10starks_newPvS_"]
pub fn starks_new(
#[link_name = "\u{1}_Z18verify_constraintsPvS_"]
pub fn verify_constraints(
pSetupCtx: *mut ::std::os::raw::c_void,
pExpressionsCtx: *mut ::std::os::raw::c_void,
pParams: *mut ::std::os::raw::c_void,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
#[link_name = "\u{1}_Z11params_freePv"]
pub fn params_free(pParams: *mut ::std::os::raw::c_void) -> *mut ::std::os::raw::c_void;
}
extern "C" {
#[link_name = "\u{1}_Z10starks_newPv"]
pub fn starks_new(pSetupCtx: *mut ::std::os::raw::c_void) -> *mut ::std::os::raw::c_void;
}
extern "C" {
#[link_name = "\u{1}_Z11starks_freePv"]
pub fn starks_free(pStarks: *mut ::std::os::raw::c_void);
Expand All @@ -190,17 +181,31 @@ extern "C" {
root: *mut ::std::os::raw::c_void,
);
}
extern "C" {
#[link_name = "\u{1}_Z15prepare_fri_polPvS_"]
pub fn prepare_fri_pol(
pStarks: *mut ::std::os::raw::c_void,
pParams: *mut ::std::os::raw::c_void,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
#[link_name = "\u{1}_Z24calculate_fri_polynomialPvS_"]
pub fn calculate_fri_polynomial(
pStarks: *mut ::std::os::raw::c_void,
pParams: *mut ::std::os::raw::c_void,
);
}
extern "C" {
#[link_name = "\u{1}_Z29calculate_quotient_polynomialPvS_"]
pub fn calculate_quotient_polynomial(
pExpressionsCtx: *mut ::std::os::raw::c_void,
pStarks: *mut ::std::os::raw::c_void,
pParams: *mut ::std::os::raw::c_void,
);
}
extern "C" {
#[link_name = "\u{1}_Z28calculate_impols_expressionsPvS_m"]
pub fn calculate_impols_expressions(
pExpressionsCtx: *mut ::std::os::raw::c_void,
pStarks: *mut ::std::os::raw::c_void,
pParams: *mut ::std::os::raw::c_void,
step: u64,
);
Expand All @@ -223,14 +228,6 @@ extern "C" {
pProof: *mut ::std::os::raw::c_void,
);
}
extern "C" {
#[link_name = "\u{1}_Z15compute_fri_polPvmS_"]
pub fn compute_fri_pol(
pStarks: *mut ::std::os::raw::c_void,
step: u64,
pParams: *mut ::std::os::raw::c_void,
) -> *mut ::std::os::raw::c_void;
}
extern "C" {
#[link_name = "\u{1}_Z19compute_fri_foldingPvmS_S_S_"]
pub fn compute_fri_folding(
Expand Down Expand Up @@ -315,7 +312,7 @@ extern "C" {
extern "C" {
#[link_name = "\u{1}_Z13print_by_namePvS_PcPmmmb"]
pub fn print_by_name(
pExpressionsCtx: *mut ::std::os::raw::c_void,
pSetupCtx: *mut ::std::os::raw::c_void,
pParams: *mut ::std::os::raw::c_void,
name: *mut ::std::os::raw::c_char,
lengths: *mut u64,
Expand All @@ -327,7 +324,7 @@ extern "C" {
extern "C" {
#[link_name = "\u{1}_Z16print_expressionPvS_mmm"]
pub fn print_expression(
pExpressionCtx: *mut ::std::os::raw::c_void,
pSetupCtx: *mut ::std::os::raw::c_void,
pol: *mut ::std::os::raw::c_void,
dim: u64,
first_value: u64,
Expand Down
Loading

0 comments on commit 23329ee

Please sign in to comment.