From 678a249dcb6d7f7977e34810ab587e2ce4efcd9e Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Fri, 16 Feb 2024 18:28:54 +0000 Subject: [PATCH] feat: allow for reduced n srs for verification (#716) --- Cargo.lock | 4 +-- examples/conv2d_mnist/main.rs | 2 ++ src/circuit/tests.rs | 55 +++++++++++++++++++++++++---------- src/commands.rs | 5 ++++ src/execute.rs | 34 ++++++++++++++++++---- src/graph/mod.rs | 8 +++++ src/graph/model.rs | 4 +-- src/pfsys/mod.rs | 9 ++++-- src/python.rs | 11 ++++++- src/wasm.rs | 19 +++++------- tests/integration_tests.rs | 24 +++++++++++++++ tests/wasm.rs | 4 +-- 12 files changed, 138 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5e21b5bc9..d8fd9f028 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2262,7 +2262,7 @@ dependencies = [ [[package]] name = "halo2_gadgets" version = "0.2.0" -source = "git+https://github.com/zkonduit/halo2?branch=main#ca603c14eb57030739b252e580a979023fa59040" +source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f" dependencies = [ "arrayvec 0.7.4", "bitvec 1.0.1", @@ -2279,7 +2279,7 @@ dependencies = [ [[package]] name = "halo2_proofs" version = "0.3.0" -source = "git+https://github.com/zkonduit/halo2?branch=main#ca603c14eb57030739b252e580a979023fa59040" +source = "git+https://github.com/zkonduit/halo2?branch=main#fe7522c85c8c434d7ceb9f663b0fb51909b9994f" dependencies = [ "blake2b_simd", "env_logger", diff --git a/examples/conv2d_mnist/main.rs b/examples/conv2d_mnist/main.rs index 59e60514d..21ee37f9a 100644 --- a/examples/conv2d_mnist/main.rs +++ b/examples/conv2d_mnist/main.rs @@ -6,6 +6,7 @@ use ezkl::fieldutils; use ezkl::fieldutils::i32_to_felt; use ezkl::tensor::*; use halo2_proofs::dev::MockProver; +use halo2_proofs::poly::commitment::Params; use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK}; use halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, @@ -489,6 +490,7 @@ pub fn runconv() { strategy, pi_for_real_prover, &mut transcript, + params.n(), ); assert!(verify.is_ok()); diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs index 5852bb590..d68bc273e 100644 --- a/src/circuit/tests.rs +++ b/src/circuit/tests.rs @@ -246,7 +246,7 @@ mod matmul_col_overflow { #[cfg(test)] #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] mod matmul_col_ultra_overflow_double_col { - use halo2_proofs::poly::commitment::ParamsProver; + use halo2_proofs::poly::commitment::{Params, ParamsProver}; use super::*; @@ -349,8 +349,13 @@ mod matmul_col_ultra_overflow_double_col { let strategy = halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params()); let vk = pk.get_vk(); - let result = - crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy); + let result = crate::pfsys::verify_proof_circuit_kzg( + params.verifier_params(), + proof, + vk, + strategy, + params.n(), + ); assert!(result.is_ok()); @@ -361,7 +366,7 @@ mod matmul_col_ultra_overflow_double_col { #[cfg(test)] #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] mod matmul_col_ultra_overflow { - use halo2_proofs::poly::commitment::ParamsProver; + use halo2_proofs::poly::commitment::{Params, ParamsProver}; use super::*; @@ -463,8 +468,13 @@ mod matmul_col_ultra_overflow { let strategy = halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params()); let vk = pk.get_vk(); - let result = - crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy); + let result = crate::pfsys::verify_proof_circuit_kzg( + params.verifier_params(), + proof, + vk, + strategy, + params.n(), + ); assert!(result.is_ok()); @@ -1140,7 +1150,7 @@ mod conv { #[cfg(test)] #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] mod conv_col_ultra_overflow { - use halo2_proofs::poly::commitment::ParamsProver; + use halo2_proofs::poly::commitment::{Params, ParamsProver}; use super::*; @@ -1262,8 +1272,13 @@ mod conv_col_ultra_overflow { let strategy = halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params()); let vk = pk.get_vk(); - let result = - crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy); + let result = crate::pfsys::verify_proof_circuit_kzg( + params.verifier_params(), + proof, + vk, + strategy, + params.n(), + ); assert!(result.is_ok()); @@ -1275,7 +1290,7 @@ mod conv_col_ultra_overflow { // not wasm 32 unknown #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] mod conv_relu_col_ultra_overflow { - use halo2_proofs::poly::commitment::ParamsProver; + use halo2_proofs::poly::commitment::{Params, ParamsProver}; use super::*; @@ -1412,8 +1427,13 @@ mod conv_relu_col_ultra_overflow { let strategy = halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params()); let vk = pk.get_vk(); - let result = - crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy); + let result = crate::pfsys::verify_proof_circuit_kzg( + params.verifier_params(), + proof, + vk, + strategy, + params.n(), + ); assert!(result.is_ok()); @@ -2343,7 +2363,7 @@ mod lookup_ultra_overflow { use halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, plonk::{Circuit, ConstraintSystem, Error}, - poly::commitment::ParamsProver, + poly::commitment::{Params, ParamsProver}, }; #[derive(Clone)] @@ -2447,8 +2467,13 @@ mod lookup_ultra_overflow { let strategy = halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params()); let vk = pk.get_vk(); - let result = - crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, vk, strategy); + let result = crate::pfsys::verify_proof_circuit_kzg( + params.verifier_params(), + proof, + vk, + strategy, + params.n(), + ); assert!(result.is_ok()); diff --git a/src/commands.rs b/src/commands.rs index db377fc4d..d67348eec 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -85,6 +85,8 @@ pub const DEFAULT_VK_SOL: &str = "vk.sol"; pub const DEFAULT_VK_ABI: &str = "vk.abi"; /// Default scale rebase multipliers for calibration pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10"; +/// Default use reduced srs for verification +pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false"; impl std::fmt::Display for TranscriptType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -695,6 +697,9 @@ pub enum Commands { /// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs #[arg(long)] srs_path: Option, + /// Reduce SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony) + #[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)] + reduced_srs: Option, }, /// Verifies an aggregate proof, returning accept or reject VerifyAggr { diff --git a/src/execute.rs b/src/execute.rs index ab0bdcbe8..d03f77722 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -366,7 +366,8 @@ pub async fn run(command: Commands) -> Result> { settings_path, vk_path, srs_path, - } => verify(proof_path, settings_path, vk_path, srs_path) + reduced_srs, + } => verify(proof_path, settings_path, vk_path, srs_path, reduced_srs) .map(|e| serde_json::to_string(&e).unwrap()), Commands::VerifyAggr { proof_path, @@ -1714,6 +1715,7 @@ pub(crate) fn fuzz( bad_proof, pk.get_vk(), strategy.clone(), + params.n(), ) .map_err(|_| ()) }; @@ -1744,6 +1746,7 @@ pub(crate) fn fuzz( bad_proof, pk.get_vk(), strategy.clone(), + params.n(), ) .map_err(|_| ()) }; @@ -1780,6 +1783,7 @@ pub(crate) fn fuzz( proof.clone(), bad_vk, strategy.clone(), + params.n(), ) .map_err(|_| ()) }; @@ -1811,6 +1815,7 @@ pub(crate) fn fuzz( bad_proof, pk.get_vk(), strategy.clone(), + params.n(), ) .map_err(|_| ()) }; @@ -1846,6 +1851,7 @@ pub(crate) fn fuzz( bad_proof, pk.get_vk(), strategy.clone(), + params.n(), ) .map_err(|_| ()) }; @@ -2031,15 +2037,33 @@ pub(crate) fn verify( settings_path: PathBuf, vk_path: PathBuf, srs_path: Option, + reduced_srs: Option, ) -> Result> { let circuit_settings = GraphSettings::load(&settings_path)?; - let params = load_params_cmd(srs_path, circuit_settings.run_args.logrows)?; + + let params = if let Some(reduced_srs) = reduced_srs { + if reduced_srs { + load_params_cmd(srs_path, circuit_settings.log2_total_instances())? + } else { + load_params_cmd(srs_path, circuit_settings.run_args.logrows)? + } + } else { + load_params_cmd(srs_path, circuit_settings.run_args.logrows)? + }; + let proof = Snark::load::>(&proof_path)?; let strategy = KZGSingleStrategy::new(params.verifier_params()); - let vk = load_vk::, Fr, GraphCircuit>(vk_path, circuit_settings)?; + let vk = + load_vk::, Fr, GraphCircuit>(vk_path, circuit_settings.clone())?; let now = Instant::now(); - let result = verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy); + let result = verify_proof_circuit_kzg( + params.verifier_params(), + proof, + &vk, + strategy, + 1 << circuit_settings.run_args.logrows, + ); let elapsed = now.elapsed(); info!( "verify took {}.{}", @@ -2063,7 +2087,7 @@ pub(crate) fn verify_aggr( let strategy = AccumulatorStrategy::new(params.verifier_params()); let vk = load_vk::, Fr, AggregationCircuit>(vk_path, ())?; let now = Instant::now(); - let result = verify_proof_circuit_kzg(¶ms, proof, &vk, strategy); + let result = verify_proof_circuit_kzg(¶ms, proof, &vk, strategy, 1 << logrows); let elapsed = now.elapsed(); info!( diff --git a/src/graph/mod.rs b/src/graph/mod.rs index a35d69741..adf10c149 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -472,6 +472,14 @@ impl GraphSettings { instances } + /// calculate the log2 of the total number of instances + pub fn log2_total_instances(&self) -> u32 { + let sum = self.total_instances().iter().sum::(); + + // max between 1 and the log2 of the sums + std::cmp::max((sum as f64).log2().ceil() as u32, 1) + } + /// save params to file pub fn save(&self, path: &std::path::PathBuf) -> Result<(), std::io::Error> { let encoded = serde_json::to_string(&self)?; diff --git a/src/graph/model.rs b/src/graph/model.rs index fe978922e..9bb642024 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -611,7 +611,7 @@ impl Model { for (symbol, value) in run_args.variables.iter() { let symbol = model.symbol_table.sym(symbol); symbol_values = symbol_values.with(&symbol, *value as i64); - info!("set {} to {}", symbol, value); + debug!("set {} to {}", symbol, value); } // Note: do not optimize the model, as the layout will depend on underlying hardware @@ -1401,7 +1401,7 @@ impl Model { // Then number of columns in the circuits #[cfg(not(target_arch = "wasm32"))] - info!( + debug!( "{} {} {} (coord={}, constants={})", "model uses".blue(), region.row().to_string().blue(), diff --git a/src/pfsys/mod.rs b/src/pfsys/mod.rs index 0f86e868d..c33c4c455 100644 --- a/src/pfsys/mod.rs +++ b/src/pfsys/mod.rs @@ -571,6 +571,7 @@ where verifier_params, pk.get_vk(), strategy, + verifier_params.n(), )?; } let elapsed = now.elapsed(); @@ -658,6 +659,7 @@ pub fn verify_proof_circuit< params: &'params Scheme::ParamsVerifier, vk: &VerifyingKey, strategy: Strategy, + orig_n: u64, ) -> Result where Scheme::Scalar: SerdeObject @@ -678,7 +680,7 @@ where trace!("instances {:?}", instances); let mut transcript = TranscriptReadBuffer::init(Cursor::new(snark.proof.clone())); - verify_proof::(params, vk, strategy, instances, &mut transcript) + verify_proof::(params, vk, strategy, instances, &mut transcript, orig_n) } /// Loads a [VerifyingKey] at `path`. @@ -856,6 +858,7 @@ pub(crate) fn verify_proof_circuit_kzg< proof: Snark, vk: &VerifyingKey, strategy: Strategy, + orig_n: u64, ) -> Result { match proof.transcript_type { TranscriptType::EVM => verify_proof_circuit::< @@ -865,7 +868,7 @@ pub(crate) fn verify_proof_circuit_kzg< _, _, EvmTranscript, - >(&proof, params, vk, strategy), + >(&proof, params, vk, strategy, orig_n), TranscriptType::Poseidon => verify_proof_circuit::< Fr, VerifierSHPLONK<'_, Bn256>, @@ -873,7 +876,7 @@ pub(crate) fn verify_proof_circuit_kzg< _, _, PoseidonTranscript, - >(&proof, params, vk, strategy), + >(&proof, params, vk, strategy, orig_n), } } diff --git a/src/python.rs b/src/python.rs index 9c36bc352..f26f0b23f 100644 --- a/src/python.rs +++ b/src/python.rs @@ -689,14 +689,23 @@ fn prove( settings_path=PathBuf::from(DEFAULT_SETTINGS), vk_path=PathBuf::from(DEFAULT_VK), srs_path=None, + non_reduced_srs=Some(DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse::().unwrap()), ))] fn verify( proof_path: PathBuf, settings_path: PathBuf, vk_path: PathBuf, srs_path: Option, + non_reduced_srs: Option, ) -> Result { - crate::execute::verify(proof_path, settings_path, vk_path, srs_path).map_err(|e| { + crate::execute::verify( + proof_path, + settings_path, + vk_path, + srs_path, + non_reduced_srs, + ) + .map_err(|e| { let err_str = format!("Failed to run verify: {}", e); PyRuntimeError::new_err(err_str) })?; diff --git a/src/wasm.rs b/src/wasm.rs index 79a347b76..5b84bea3b 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -311,13 +311,19 @@ pub fn verify( let vk = VerifyingKey::::read::<_, GraphCircuit>( &mut reader, halo2_proofs::SerdeFormat::RawBytes, - circuit_settings, + circuit_settings.clone(), ) .map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?; let strategy = KZGSingleStrategy::new(params.verifier_params()); - let result = verify_proof_circuit_kzg(params.verifier_params(), snark, &vk, strategy); + let result = verify_proof_circuit_kzg( + params.verifier_params(), + snark, + &vk, + strategy, + 1 << circuit_settings.run_args.logrows, + ); match result { Ok(_) => Ok(true), @@ -387,15 +393,6 @@ pub fn prove( .into_bytes()) } -/// print hex representation of a proof -#[wasm_bindgen] -#[allow(non_snake_case)] -pub fn printProofHex(proof: wasm_bindgen::Clamped>) -> Result { - let proof: crate::pfsys::Snark = serde_json::from_slice(&proof[..]) - .map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?; - let hex_str = hex::encode(proof.proof); - Ok(format!("0x{}", hex_str)) -} // VALIDATION FUNCTIONS /// Witness file validation diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index d99bb4d95..6c71eda19 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -1745,6 +1745,30 @@ mod native_tests { .status() .expect("failed to execute process"); assert!(status.success()); + + // load settings file + let settings = + std::fs::read_to_string(settings_path.clone()).expect("failed to read settings file"); + + let graph_settings = serde_json::from_str::(&settings) + .expect("failed to parse settings file"); + + // get_srs for the graph_settings_num_instances + let _ = download_srs(graph_settings.log2_total_instances()); + + let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) + .args([ + "verify", + format!("--settings-path={}", settings_path).as_str(), + "--proof-path", + &format!("{}/{}/proof.pf", test_dir, example_name), + "--vk-path", + &format!("{}/{}/key.vk", test_dir, example_name), + "--reduced-srs=true", + ]) + .status() + .expect("failed to execute process"); + assert!(status.success()); } // prove-serialize-verify, the usual full path diff --git a/tests/wasm.rs b/tests/wasm.rs index ecff4a718..23ea039d4 100644 --- a/tests/wasm.rs +++ b/tests/wasm.rs @@ -9,8 +9,8 @@ mod wasm32 { use ezkl::pfsys; use ezkl::wasm::{ bufferToVecOfstring, compiledCircuitValidation, encodeVerifierCalldata, genPk, genVk, - genWitness, inputValidation, pkValidation, poseidonHash, printProofHex, proofValidation, - prove, settingsValidation, srsValidation, stringToFelt, stringToFloat, stringToInt, + genWitness, inputValidation, pkValidation, poseidonHash, proofValidation, prove, + settingsValidation, srsValidation, stringToFelt, stringToFloat, stringToInt, u8_array_to_u128_le, verify, vkValidation, witnessValidation, }; use halo2_solidity_verifier::encode_calldata;