From 48223cca11fa98baa62f1e7cd663b6e8b2d565a0 Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Wed, 3 Apr 2024 02:26:50 +0100 Subject: [PATCH] fix: make commitment optional for backwards compat (#762) --- src/commands.rs | 6 ++-- src/execute.rs | 58 +++++++++++++++++++------------------ src/lib.rs | 10 +++++-- src/python.rs | 12 +++++++- src/tensor/val.rs | 7 ++--- src/wasm.rs | 7 +++-- tests/integration_tests.rs | 4 +-- tests/wasm/model.compiled | Bin 1724 -> 1721 bytes tests/wasm/settings.json | 3 +- 9 files changed, 63 insertions(+), 44 deletions(-) diff --git a/src/commands.rs b/src/commands.rs index 5c7cb0f44..28ee046ee 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -444,7 +444,7 @@ pub enum Commands { disable_selector_compression: bool, /// commitment used #[arg(long, default_value = DEFAULT_COMMITMENT)] - commitment: Commitments, + commitment: Option, }, /// Aggregates proofs :) Aggregate { @@ -479,7 +479,7 @@ pub enum Commands { split_proofs: bool, /// commitment used #[arg(long, default_value = DEFAULT_COMMITMENT)] - commitment: Commitments, + commitment: Option, }, /// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements CompileCircuit { @@ -726,7 +726,7 @@ pub enum Commands { logrows: u32, /// commitment #[arg(long, default_value = DEFAULT_COMMITMENT)] - commitment: Commitments, + commitment: Option, }, #[cfg(not(target_arch = "wasm32"))] /// Deploys an evm verifier that is generated by ezkl diff --git a/src/execute.rs b/src/execute.rs index c6cf094ac..dcc5e3c68 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -339,7 +339,7 @@ pub async fn run(command: Commands) -> Result> { logrows, split_proofs, disable_selector_compression, - commitment, + commitment.into(), ), Commands::Aggregate { proof_path, @@ -360,7 +360,7 @@ pub async fn run(command: Commands) -> Result> { logrows, check_mode, split_proofs, - commitment, + commitment.into(), ) .map(|e| serde_json::to_string(&e).unwrap()), Commands::Verify { @@ -384,7 +384,7 @@ pub async fn run(command: Commands) -> Result> { srs_path, logrows, reduced_srs, - commitment, + commitment.into(), ) .map(|e| serde_json::to_string(&e).unwrap()), #[cfg(not(target_arch = "wasm32"))] @@ -586,7 +586,7 @@ pub(crate) async fn get_srs_cmd( } else if let Some(settings_p) = settings_path { if settings_p.exists() { let settings = GraphSettings::load(&settings_p)?; - settings.run_args.commitment + settings.run_args.commitment.into() } else { return Err(err_string.into()); } @@ -666,21 +666,17 @@ pub(crate) async fn gen_witness( // if any of the settings have kzg visibility then we need to load the srs + let commitment: Commitments = settings.run_args.commitment.into(); + let start_time = Instant::now(); let witness = if settings.module_requires_polycommit() { - if get_srs_path( - settings.run_args.logrows, - srs_path.clone(), - settings.run_args.commitment, - ) - .exists() - { - match settings.run_args.commitment { + if get_srs_path(settings.run_args.logrows, srs_path.clone(), commitment).exists() { + match Commitments::from(settings.run_args.commitment) { Commitments::KZG => { let srs: ParamsKZG = load_params_prover::>( srs_path.clone(), settings.run_args.logrows, - settings.run_args.commitment, + commitment, )?; circuit.forward::>( &mut input, @@ -694,7 +690,7 @@ pub(crate) async fn gen_witness( load_params_prover::>( srs_path.clone(), settings.run_args.logrows, - settings.run_args.commitment, + commitment, )?; circuit.forward::>( &mut input, @@ -1303,17 +1299,19 @@ pub(crate) fn create_evm_verifier( render_vk_seperately: bool, ) -> Result> { check_solc_requirement(); - let circuit_settings = GraphSettings::load(&settings_path)?; + + let settings = GraphSettings::load(&settings_path)?; + let commitment: Commitments = settings.run_args.commitment.into(); let params = load_params_verifier::>( srs_path, - circuit_settings.run_args.logrows, - circuit_settings.run_args.commitment, + settings.run_args.logrows, + commitment, )?; - let num_instance = circuit_settings.total_instances(); + let num_instance = settings.total_instances(); let num_instance: usize = num_instance.iter().sum::(); - let vk = load_vk::, GraphCircuit>(vk_path, circuit_settings)?; + let vk = load_vk::, GraphCircuit>(vk_path, settings)?; trace!("params computed"); let generator = halo2_solidity_verifier::SolidityGenerator::new( @@ -1347,17 +1345,18 @@ pub(crate) fn create_evm_vk( abi_path: PathBuf, ) -> Result> { check_solc_requirement(); - let circuit_settings = GraphSettings::load(&settings_path)?; + let settings = GraphSettings::load(&settings_path)?; + let commitment: Commitments = settings.run_args.commitment.into(); let params = load_params_verifier::>( srs_path, - circuit_settings.run_args.logrows, - circuit_settings.run_args.commitment, + settings.run_args.logrows, + commitment, )?; - let num_instance = circuit_settings.total_instances(); + let num_instance = settings.total_instances(); let num_instance: usize = num_instance.iter().sum::(); - let vk = load_vk::, GraphCircuit>(vk_path, circuit_settings)?; + let vk = load_vk::, GraphCircuit>(vk_path, settings)?; trace!("params computed"); let generator = halo2_solidity_verifier::SolidityGenerator::new( @@ -1626,8 +1625,9 @@ pub(crate) fn setup( } let logrows = circuit.settings().run_args.logrows; + let commitment: Commitments = circuit.settings().run_args.commitment.into(); - let pk = match circuit.settings().run_args.commitment { + let pk = match commitment { Commitments::KZG => { let params = load_params_prover::>( srs_path, @@ -1736,7 +1736,8 @@ pub(crate) fn prove( let transcript: TranscriptType = proof_type.into(); let proof_split_commits: Option = data.into(); - let commitment = circuit_settings.run_args.commitment; + let commitment = circuit_settings.run_args.commitment.into(); + let logrows = circuit_settings.run_args.logrows; // creates and verifies the proof let mut snark = match commitment { Commitments::KZG => { @@ -1745,7 +1746,7 @@ pub(crate) fn prove( let params = load_params_prover::>( srs_path, - circuit_settings.run_args.logrows, + logrows, Commitments::KZG, )?; match strategy { @@ -2187,8 +2188,9 @@ pub(crate) fn verify( let circuit_settings = GraphSettings::load(&settings_path)?; let logrows = circuit_settings.run_args.logrows; + let commitment = circuit_settings.run_args.commitment.into(); - match circuit_settings.run_args.commitment { + match commitment { Commitments::KZG => { let proof = Snark::load::>(&proof_path)?; let params: ParamsKZG = if reduced_srs { diff --git a/src/lib.rs b/src/lib.rs index b203acfe0..96872917d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -114,6 +114,12 @@ pub enum Commitments { IPA, } +impl From> for Commitments { + fn from(value: Option) -> Self { + value.unwrap_or(Commitments::KZG) + } +} + impl FromStr for Commitments { type Err = String; fn from_str(s: &str) -> Result { @@ -214,7 +220,7 @@ pub struct RunArgs { pub check_mode: CheckMode, /// commitment scheme #[arg(long, default_value = "kzg")] - pub commitment: Commitments, + pub commitment: Option, } impl Default for RunArgs { @@ -234,7 +240,7 @@ impl Default for RunArgs { div_rebasing: false, rebase_frac_zero_constants: false, check_mode: CheckMode::UNSAFE, - commitment: Commitments::KZG, + commitment: None, } } } diff --git a/src/python.rs b/src/python.rs index 837fbdff9..82d2ee420 100644 --- a/src/python.rs +++ b/src/python.rs @@ -197,7 +197,7 @@ impl From for RunArgs { div_rebasing: py_run_args.div_rebasing, rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants, check_mode: py_run_args.check_mode, - commitment: py_run_args.commitment.into(), + commitment: Some(py_run_args.commitment.into()), } } } @@ -234,6 +234,16 @@ pub enum PyCommitments { IPA, } +impl From> for PyCommitments { + fn from(commitment: Option) -> Self { + match commitment { + Some(Commitments::KZG) => PyCommitments::KZG, + Some(Commitments::IPA) => PyCommitments::IPA, + None => PyCommitments::KZG, + } + } +} + impl From for Commitments { fn from(py_commitments: PyCommitments) -> Self { match py_commitments { diff --git a/src/tensor/val.rs b/src/tensor/val.rs index 9c81febee..6bb1abe46 100644 --- a/src/tensor/val.rs +++ b/src/tensor/val.rs @@ -1,3 +1,5 @@ +use core::{iter::FilterMap, slice::Iter}; + use crate::circuit::region::ConstantsMap; use super::{ @@ -450,10 +452,7 @@ impl ValTensor { /// Returns the number of constants in the [ValTensor]. pub fn create_constants_map_iterator( &self, - ) -> core::iter::FilterMap< - core::slice::Iter<'_, ValType>, - fn(&ValType) -> Option<(F, ValType)>, - > { + ) -> FilterMap>, fn(&ValType) -> Option<(F, ValType)>> { match self { ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| { if let ValType::Constant(v) = x { diff --git a/src/wasm.rs b/src/wasm.rs index ec3e9af8c..26e1ce1e2 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -337,8 +337,10 @@ pub fn verify( let orig_n = 1 << circuit_settings.run_args.logrows; + let commitment = circuit_settings.run_args.commitment.into(); + let mut reader = std::io::BufReader::new(&srs[..]); - let result = match circuit_settings.run_args.commitment { + let result = match commitment { Commitments::KZG => { let params: ParamsKZG = halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) @@ -521,8 +523,9 @@ pub fn prove( // read in kzg params let mut reader = std::io::BufReader::new(&srs[..]); + let commitment = circuit.settings().run_args.commitment.into(); // creates and verifies the proof - let proof = match circuit.settings().run_args.commitment { + let proof = match commitment { Commitments::KZG => { let params: ParamsKZG = halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader) diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 7c1c1cd28..a8b6fa730 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -122,7 +122,7 @@ mod native_tests { let settings: GraphSettings = serde_json::from_str(&settings).unwrap(); let logrows = settings.run_args.logrows; - download_srs(logrows, settings.run_args.commitment); + download_srs(logrows, settings.run_args.commitment.into()); } fn mv_test_(test_dir: &str, test: &str) { @@ -1971,7 +1971,7 @@ mod native_tests { .expect("failed to parse settings file"); // get_srs for the graph_settings_num_instances - download_srs(1, graph_settings.run_args.commitment); + download_srs(1, graph_settings.run_args.commitment.into()); let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ diff --git a/tests/wasm/model.compiled b/tests/wasm/model.compiled index 40041e853be3a1089e68bbb5382d393e556cdff0..d36215c15add6fdcdc7978318007db1c7e032438 100644 GIT binary patch delta 12 TcmdnPyOVc=8|!9I)?!8g9y