From 6dd9aacb7f48121cf7cd8c68372f20ced00da312 Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Fri, 7 Jul 2023 12:21:53 +0100 Subject: [PATCH] refactor: only `GraphData` can be on-chain (#342) --- benches/poseidon.rs | 2 +- examples/onnx/1l_gelu_noappx/input.json | 2 +- examples/onnx/1l_var/input.json | 2 +- src/circuit/ops/hybrid.rs | 2 +- src/commands.rs | 24 +- src/eth.rs | 13 +- src/execute.rs | 100 ++---- src/graph/input.rs | 343 ++---------------- src/graph/mod.rs | 458 ++++++++++++++++-------- src/graph/model.rs | 11 + tests/integration_tests.rs | 126 ++++--- tests/python/binding_tests.py | 4 +- tests/wasm/test.witness.json | 7 +- 13 files changed, 478 insertions(+), 616 deletions(-) diff --git a/benches/poseidon.rs b/benches/poseidon.rs index 2ef06e57a..d1d76b662 100644 --- a/benches/poseidon.rs +++ b/benches/poseidon.rs @@ -70,7 +70,7 @@ fn runposeidon(c: &mut Criterion) { PoseidonChip::::run(message.to_vec()) .unwrap(); - let mut image = Tensor::from(message.into_iter().map(|x| Value::known(x))); + let mut image = Tensor::from(message.into_iter().map(Value::known)); image.reshape(&[1, *size]); let circuit = MyCircuit { diff --git a/examples/onnx/1l_gelu_noappx/input.json b/examples/onnx/1l_gelu_noappx/input.json index 01b8377e0..bbb1b4e64 100644 --- a/examples/onnx/1l_gelu_noappx/input.json +++ b/examples/onnx/1l_gelu_noappx/input.json @@ -1 +1 @@ -{"input_data":[[0.61017877,0.21496391,0.8960367]],"input_shapes":[[3]],"output_data":[[0.44274902,0.12817383,0.73349]]} \ No newline at end of file +{"input_data":[[0.61017877,0.21496391,0.8960367]],"input_shapes":[[3]],"output_data":[[0.44140625,0.12890625,0.734375]]} \ No newline at end of file diff --git a/examples/onnx/1l_var/input.json b/examples/onnx/1l_var/input.json index 848d4b16e..0f54c6068 100644 --- a/examples/onnx/1l_var/input.json +++ b/examples/onnx/1l_var/input.json @@ -1 +1 @@ -{"input_data":[[0.048659664,0.040321846,0.092751384,0.058180947,0.019983828,0.096692465,0.07317094,0.06064367,0.052526843]],"output_data":[[0.00010251999,0.00026655197,0.00034856796]]} \ No newline at end of file +{"input_data":[[0.048659664,0.040321846,0.092751384,0.058180947,0.019983828,0.096692465,0.07317094,0.06064367,0.052526843]],"output_data":[[0.0,0.0,0.0]]} \ No newline at end of file diff --git a/src/circuit/ops/hybrid.rs b/src/circuit/ops/hybrid.rs index 202689ded..87604d3c5 100644 --- a/src/circuit/ops/hybrid.rs +++ b/src/circuit/ops/hybrid.rs @@ -57,7 +57,7 @@ impl Op for HybridOp { HybridOp::Softmax { scales } => { tensor::ops::nonlinearities::multi_dim_softmax(&x, scales.0, scales.1) } - HybridOp::RangeCheck(..) => (x.clone(), vec![]), + HybridOp::RangeCheck(..) => (x, vec![]), }; // convert back to felt diff --git a/src/commands.rs b/src/commands.rs index 80aefdaae..dc61257fc 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -390,10 +390,10 @@ pub enum Commands { settings_path: Option, }, #[cfg(not(target_arch = "wasm32"))] - SetupTestEVMWitness { - /// The path to the .json witness file, which should include both the network input (possibly private) and the network output (public input to the proof) - #[arg(short = 'W', long)] - witness: PathBuf, + SetupTestEVMData { + /// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof) + #[arg(short = 'D', long)] + data: PathBuf, /// The path to the .onnx model file #[arg(short = 'M', long)] model: PathBuf, @@ -403,8 +403,8 @@ pub enum Commands { /// For testing purposes only. The optional path to the .json data file that will be generated that contains the OnChain data storage information /// derived from the file information in the data .json file. /// Should include both the network input (possibly private) and the network output (public input to the proof) - #[arg(short = 'D', long)] - test_witness: PathBuf, + #[arg(short = 'T', long)] + test_data: PathBuf, /// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state #[arg(short = 'U', long)] rpc_url: Option, @@ -512,13 +512,13 @@ pub enum Commands { /// If not set will just use the default unoptimized SOLC configuration. #[arg(long)] optimizer_runs: Option, - /// The path to the .json witness file, which should + /// The path to the .json data file, which should /// contain the necessary calldata and accoount addresses /// needed need to read from all the on-chain /// view functions that return the data that the network /// ingests as inputs. - #[arg(short = 'W', long)] - witness: PathBuf, + #[arg(short = 'D', long)] + data: PathBuf, // todo, optionally allow supplying proving key }, @@ -595,9 +595,9 @@ pub enum Commands { #[cfg(not(target_arch = "wasm32"))] #[command(name = "deploy-evm-da-verifier", arg_required_else_help = true)] DeployEvmDataAttestationVerifier { - /// The path to the .json witness file, which should include both the network input (possibly private) and the network output (public input to the proof) - #[arg(short = 'W', long)] - witness: PathBuf, + /// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof) + #[arg(short = 'D', long)] + data: PathBuf, /// The path to load circuit params from #[arg(long)] settings_path: PathBuf, diff --git a/src/eth.rs b/src/eth.rs index ad33f654b..76294295c 100644 --- a/src/eth.rs +++ b/src/eth.rs @@ -1,4 +1,5 @@ -use crate::graph::input::{CallsToAccount, GraphWitness, WitnessSource}; +use crate::graph::input::{CallsToAccount, GraphData}; +use crate::graph::DataSource; #[cfg(not(target_arch = "wasm32"))] use crate::graph::GraphSettings; use crate::pfsys::evm::{DeploymentCode, EvmVerificationError}; @@ -145,13 +146,13 @@ pub async fn deploy_verifier_via_solidity( /// pub async fn deploy_da_verifier_via_solidity( settings_path: PathBuf, - witness: PathBuf, + input: PathBuf, sol_code_path: PathBuf, rpc_url: Option<&str>, ) -> Result> { let (_, client) = setup_eth_backend(rpc_url).await?; - let witness = GraphWitness::from_path(witness)?; + let input = GraphData::from_path(input)?; let settings = GraphSettings::load(&settings_path)?; @@ -164,12 +165,12 @@ pub async fn deploy_da_verifier_via_solidity( let mut instance_idx = 0; let mut contract_instance_offset = 0; - if let WitnessSource::OnChain(source) = witness.input_data { + if let DataSource::OnChain(source) = input.input_data { for call in source.calls { calls_to_accounts.push(call); instance_idx += 1; } - } else if let WitnessSource::File(source) = witness.input_data { + } else if let DataSource::File(source) = input.input_data { if settings.run_args.input_visibility.is_public() { instance_idx += source.len(); for s in source { @@ -178,7 +179,7 @@ pub async fn deploy_da_verifier_via_solidity( } } - if let WitnessSource::OnChain(source) = witness.output_data { + if let Some(DataSource::OnChain(source)) = input.output_data { let output_scales = settings.model_output_scales; for call in source.calls { calls_to_accounts.push(call); diff --git a/src/execute.rs b/src/execute.rs index 50ca14ab1..d3268cedf 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -6,7 +6,7 @@ use crate::commands::{Cli, Commands, RunArgs}; use crate::eth::{deploy_da_verifier_via_solidity, deploy_verifier_via_solidity}; #[cfg(not(target_arch = "wasm32"))] use crate::eth::{fix_verifier_sol, get_contract_artifacts, verify_proof_via_solidity}; -use crate::graph::input::{GraphInput, WitnessSource}; +use crate::graph::input::GraphData; use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model}; #[cfg(not(target_arch = "wasm32"))] use crate::graph::{TestDataSource, TestSources, Visibility}; @@ -41,6 +41,7 @@ use halo2curves::ff::Field; #[cfg(not(target_arch = "wasm32"))] use indicatif::{ProgressBar, ProgressStyle}; use instant::Instant; +#[cfg(not(target_arch = "wasm32"))] use itertools::Itertools; #[cfg(not(target_arch = "wasm32"))] use log::debug; @@ -161,13 +162,13 @@ pub async fn run(cli: Cli) -> Result<(), Box> { sol_code_path, sol_bytecode_path, optimizer_runs, - witness, + data, } => create_evm_data_attestation_verifier( vk_path, srs_path, settings_path, sol_code_path, - witness, + data, sol_bytecode_path, optimizer_runs, ), @@ -195,20 +196,20 @@ pub async fn run(cli: Cli) -> Result<(), Box> { pk_path, } => setup(model, srs_path, settings_path, vk_path, pk_path), #[cfg(not(target_arch = "wasm32"))] - Commands::SetupTestEVMWitness { - witness, + Commands::SetupTestEVMData { + data, model, settings_path, - test_witness, + test_data, rpc_url, input_source, output_source, } => { setup_test_evm_witness( - witness, + data, model, settings_path, - test_witness, + test_data, rpc_url, input_source, output_source, @@ -276,12 +277,12 @@ pub async fn run(cli: Cli) -> Result<(), Box> { } => deploy_evm(sol_code_path, rpc_url, addr_path).await, #[cfg(not(target_arch = "wasm32"))] Commands::DeployEvmDataAttestationVerifier { - witness, + data, settings_path, sol_code_path, rpc_url, addr_path, - } => deploy_da_evm(witness, settings_path, sol_code_path, rpc_url, addr_path).await, + } => deploy_da_evm(data, settings_path, sol_code_path, rpc_url, addr_path).await, #[cfg(not(target_arch = "wasm32"))] Commands::VerifyEVM { proof_path, @@ -489,16 +490,16 @@ pub(crate) async fn gen_witness( let mut circuit = GraphCircuit::from_settings(&circuit_settings, &model_path, CheckMode::UNSAFE)?; - let data = GraphInput::from_path(data)?; + let data = GraphData::from_path(data)?; #[cfg(not(target_arch = "wasm32"))] - circuit.load_graph_input(&data).await?; + let input = circuit.load_graph_input(&data).await?; #[cfg(target_arch = "wasm32")] - circuit.load_graph_input(&data)?; + let input = circuit.load_graph_input(&data)?; let start_time = Instant::now(); - let res = circuit.forward()?; + let witness = circuit.forward(&input)?; trace!( "witness generation (B={:?}) took {:?}", @@ -506,32 +507,6 @@ pub(crate) async fn gen_witness( start_time.elapsed() ); - trace!( - "model forward pass output shapes: {:?}", - res.outputs.iter().map(|t| t.dims()).collect_vec() - ); - - let input_witness: Vec> = res - .inputs - .iter() - .map(|t| t.clone().into_iter().collect_vec()) - .collect(); - - let output_witness: Vec> = res - .outputs - .iter() - .map(|t| t.clone().into_iter().collect_vec()) - .collect(); - trace!("model forward pass output: {:?}", output_witness); - - let witness = GraphWitness { - input_data: WitnessSource::File(input_witness), - output_data: WitnessSource::File(output_witness), - processed_inputs: res.processed_inputs, - processed_params: res.processed_params, - processed_outputs: res.processed_outputs, - }; - if let Some(output_path) = output { serde_json::to_writer(&File::create(output_path)?, &witness)?; } @@ -596,7 +571,7 @@ pub(crate) async fn calibrate( settings_path: PathBuf, target: CalibrationTarget, ) -> Result<(), Box> { - let data = GraphInput::from_path(data)?; + let data = GraphData::from_path(data)?; // load the pre-generated settings let settings = GraphSettings::load(&settings_path)?; // now retrieve the run args @@ -647,7 +622,7 @@ pub(crate) async fn calibrate( .unwrap(); tokio::task::spawn(async move { - circuit + let data = circuit .load_graph_input(&chunk) .await .map_err(|_| "failed to load circuit inputs") @@ -657,7 +632,9 @@ pub(crate) async fn calibrate( // // ensures we have converged let params_before = circuit.settings.clone(); - circuit.calibrate().map_err(|_| "failed to calibrate")?; + circuit + .calibrate(&data) + .map_err(|_| "failed to calibrate")?; let params_after = circuit.settings.clone(); if params_before == params_after { break; @@ -754,11 +731,8 @@ pub(crate) async fn mock( let circuit_settings = GraphSettings::load(&settings_path)?; let mut circuit = GraphCircuit::from_settings(&circuit_settings, &model_path, CheckMode::SAFE)?; - let data = GraphWitness::from_path(data_path.clone())?; + let data = GraphWitness::from_path(data_path)?; - #[cfg(not(target_arch = "wasm32"))] - circuit.load_graph_witness(&data, None).await?; - #[cfg(target_arch = "wasm32")] circuit.load_graph_witness(&data)?; let public_inputs = circuit.prepare_public_inputs(&data)?; @@ -873,11 +847,11 @@ pub(crate) fn create_evm_data_attestation_verifier( srs_path: PathBuf, settings_path: PathBuf, sol_code_path: PathBuf, - witness: PathBuf, + input: PathBuf, sol_bytecode_path: Option, runs: Option, ) -> Result<(), Box> { - use crate::graph::VarVisibility; + use crate::graph::{DataSource, VarVisibility}; let settings = GraphSettings::load(&settings_path)?; let params = load_params_cmd(srs_path, settings.run_args.logrows)?; @@ -894,9 +868,9 @@ pub(crate) fn create_evm_data_attestation_verifier( let mut f = File::create(sol_code_path.clone())?; let _ = f.write(yul_code.as_bytes()); - let data = GraphWitness::from_path(witness)?; + let data = GraphData::from_path(input)?; - let output_data = if let WitnessSource::OnChain(source) = data.output_data { + let output_data = if let Some(DataSource::OnChain(source)) = data.output_data { if !visibility.output.is_public() { todo!("we currently don't support private output data on chain") } @@ -909,7 +883,7 @@ pub(crate) fn create_evm_data_attestation_verifier( None }; - let input_data = if let WitnessSource::OnChain(source) = data.input_data { + let input_data = if let DataSource::OnChain(source) = data.input_data { if !visibility.input.is_public() { todo!("we currently don't support private input data on chain") } @@ -941,14 +915,14 @@ pub(crate) fn create_evm_data_attestation_verifier( #[cfg(not(target_arch = "wasm32"))] pub(crate) async fn deploy_da_evm( - witness: PathBuf, + data: PathBuf, settings_path: PathBuf, sol_code_path: PathBuf, rpc_url: Option, addr_path: PathBuf, ) -> Result<(), Box> { let contract_address = - deploy_da_verifier_via_solidity(settings_path, witness, sol_code_path, rpc_url.as_deref()) + deploy_da_verifier_via_solidity(settings_path, data, sol_code_path, rpc_url.as_deref()) .await?; info!("Contract deployed at: {}", contract_address); @@ -1063,7 +1037,7 @@ pub(crate) async fn setup_test_evm_witness( data_path: PathBuf, model_path: PathBuf, settings_path: PathBuf, - test_witness: PathBuf, + test_data: PathBuf, rpc_url: Option, input_source: TestDataSource, output_source: TestDataSource, @@ -1071,7 +1045,7 @@ pub(crate) async fn setup_test_evm_witness( use crate::graph::TestOnChainData; info!("run this command in background to keep the instance running for testing"); - let data = GraphWitness::from_path(data_path)?; + let mut data = GraphData::from_path(data_path)?; let circuit_settings = GraphSettings::load(&settings_path)?; let mut circuit = GraphCircuit::from_settings(&circuit_settings, &model_path, CheckMode::SAFE)?; @@ -1081,8 +1055,8 @@ pub(crate) async fn setup_test_evm_witness( return Err("Both input and output cannot be from files".into()); } - let test_on_chain_witness = TestOnChainData { - data: test_witness.clone(), + let test_on_chain_data = TestOnChainData { + data: test_data.clone(), rpc: rpc_url, data_sources: TestSources { input: input_source, @@ -1091,7 +1065,7 @@ pub(crate) async fn setup_test_evm_witness( }; circuit - .load_graph_witness(&data, Some(test_on_chain_witness)) + .populate_on_chain_test_data(&mut data, test_on_chain_data) .await?; Ok(()) @@ -1114,7 +1088,7 @@ pub(crate) async fn prove( let circuit_settings = GraphSettings::load(&settings_path)?; let mut circuit = GraphCircuit::from_settings(&circuit_settings, &model_path, check_mode)?; - circuit.load_graph_witness(&data, None).await?; + circuit.load_graph_witness(&data)?; let public_inputs = circuit.prepare_public_inputs(&data)?; let circuit_settings = circuit.settings.clone(); @@ -1186,7 +1160,7 @@ pub(crate) async fn fuzz( let _r = Gag::stdout().unwrap(); let params = gen_srs::>(logrows); - let data = GraphWitness::from_path(data_path.clone())?; + let data = GraphWitness::from_path(data_path)?; // these aren't real values so the sanity checks are mostly meaningless let mut circuit = match settings_path { Some(path) => { @@ -1199,7 +1173,7 @@ pub(crate) async fn fuzz( let pk = create_keys::, Fr, GraphCircuit>(&circuit, ¶ms) .map_err(Box::::from)?; - circuit.load_graph_witness(&data, None).await?; + circuit.load_graph_witness(&data)?; let public_inputs = circuit.prepare_public_inputs(&data)?; let strategy = KZGSingleStrategy::new(¶ms); @@ -1456,7 +1430,7 @@ pub(crate) fn aggregate( check_mode: CheckMode, ) -> Result<(), Box> { // the K used for the aggregation circuit - let params = load_params_cmd(srs_path.clone(), logrows)?; + let params = load_params_cmd(srs_path, logrows)?; let mut snarks = vec![]; for proof_path in aggregation_snarks.iter() { diff --git a/src/graph/input.rs b/src/graph/input.rs index 3e96cff1d..2599e5d8b 100644 --- a/src/graph/input.rs +++ b/src/graph/input.rs @@ -18,17 +18,12 @@ use std::io::Read; // use std::collections::HashMap; use super::quantize_float; -use super::{modules::ModuleForwardResult, GraphError}; +use super::GraphError; type Decimals = u8; type Call = String; type RPCUrl = String; -#[cfg(feature = "python-bindings")] -use crate::pfsys::field_to_vecu64_montgomery; -#[cfg(feature = "python-bindings")] -use halo2curves::bn256::G1Affine; - /// #[derive(Clone, Debug, PartialOrd, PartialEq)] pub enum FileSourceInner { @@ -95,11 +90,15 @@ impl FileSourceInner { FileSourceInner::Field(f) => *f, } } + /// Convert to a float + pub fn to_float(&self) -> f64 { + match self { + FileSourceInner::Float(f) => *f, + FileSourceInner::Field(f) => crate::fieldutils::felt_to_i128(*f) as f64, + } + } } -/// Inner elements of witness coming from a witness -pub type WitnessFileSource = Vec>; - /// Inner elements of inputs/outputs coming from on-chain #[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)] pub struct OnChainSource { @@ -120,7 +119,7 @@ impl OnChainSource { #[cfg(not(target_arch = "wasm32"))] /// Create dummy local on-chain data to test the OnChain data source pub async fn test_from_file_data( - data: &WitnessFileSource, + data: &FileSource, scales: Vec, shapes: Vec>, rpc: Option<&str>, @@ -140,12 +139,7 @@ impl OnChainSource { // unquantize data let float_data = data .iter() - .zip(scales.iter()) - .map(|(t, scale)| { - t.iter() - .map(|e| ((crate::fieldutils::felt_to_i128(*e) as f64 / scale) as f32)) - .collect_vec() - }) + .map(|t| t.iter().map(|e| (e.to_float() as f32)).collect_vec()) .collect::>>(); let calls_to_accounts = test_on_chain_data(client.clone(), &float_data).await?; @@ -181,7 +175,7 @@ impl OnChainSource { let used_rpc = rpc.unwrap_or(&anvil.endpoint()).to_string(); - // Fill the input_data field of the GraphInput struct + // Fill the input_data field of the GraphData struct Ok(( inputs, OnChainSource::new(calls_to_accounts.clone(), used_rpc), @@ -274,139 +268,23 @@ impl<'de> Deserialize<'de> for DataSource { } } -/// Enum that defines source of the inputs/outputs to the EZKL model -#[derive(Clone, Debug, PartialOrd, PartialEq)] -pub enum WitnessSource { - /// .json File data source. - File(WitnessFileSource), - /// On-chain data source. The first element is the calls to the account, and the second is the RPC url. - OnChain(OnChainSource), -} -impl Default for WitnessSource { - fn default() -> Self { - WitnessSource::File(vec![vec![]]) - } -} - -impl From for WitnessSource { - fn from(data: WitnessFileSource) -> Self { - WitnessSource::File(data) - } -} - -impl From for WitnessSource { - fn from(data: OnChainSource) -> Self { - WitnessSource::OnChain(data) - } -} - -// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT -// UNTAGGED ENUMS WONT WORK :( as highlighted here: -impl<'de> Deserialize<'de> for WitnessSource { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let this_json: Box = Deserialize::deserialize(deserializer)?; - - let first_try: Result>, _> = serde_json::from_str(this_json.get()); - - if let Ok(t) = first_try { - let t: Vec> = t - .iter() - .map(|x| x.iter().map(|fp| Fp::from_raw(*fp)).collect()) - .collect(); - return Ok(WitnessSource::File(t)); - } - - let second_try: Result = serde_json::from_str(this_json.get()); - if let Ok(t) = second_try { - return Ok(WitnessSource::OnChain(t)); - } - - Err(serde::de::Error::custom( - "failed to deserialize WitnessSource", - )) - } -} - -impl Serialize for WitnessSource { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match self { - WitnessSource::File(source) => { - let field_elems: Vec> = source - .iter() - .map(|x| x.iter().map(|fp| field_to_vecu64(fp)).collect()) - .collect::>(); - field_elems.serialize(serializer) - } - WitnessSource::OnChain(source) => { - // leave it untagged - let mut state = serializer.serialize_struct("", 2)?; - state.serialize_field("rpc", &source.rpc)?; - state.serialize_field("calls", &source.calls)?; - state.end() - } - } - } -} - -/// The input tensor data and shape, and output data for the computational graph (model) as floats. -/// For example, the input might be the image data for a neural network, and the output class scores. -#[derive(Clone, Debug, Deserialize, Default)] -pub struct GraphWitness { - /// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain). - /// TODO: Add retrieve from on-chain functionality - pub input_data: WitnessSource, - /// The expected output of the model (can be empty vectors if outputs are not being constrained). - pub output_data: WitnessSource, - /// Optional hashes of the inputs (can be None if there are no commitments). Wrapped as Option for backwards compatibility - pub processed_inputs: Option, - /// Optional hashes of the params (can be None if there are no commitments). Wrapped as Option for backwards compatibility - pub processed_params: Option, - /// Optional hashes of the outputs (can be None if there are no commitments). Wrapped as Option for backwards compatibility - pub processed_outputs: Option, -} - -impl GraphWitness { - /// - pub fn new(input_data: WitnessSource, output_data: WitnessSource) -> Self { - GraphWitness { - input_data, - output_data, - processed_inputs: None, - processed_params: None, - processed_outputs: None, - } - } - /// Load the model input from a file - pub fn from_path(path: std::path::PathBuf) -> Result> { - let mut file = std::fs::File::open(path)?; - let mut data = String::new(); - file.read_to_string(&mut data)?; - serde_json::from_str(&data).map_err(|e| e.into()) - } - - /// Save the model input to a file - pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box> { - serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into()) - } -} /// Input to graph as a datasource -/// Always use JSON serialization for GraphInput. Seriously. +/// Always use JSON serialization for GraphData. Seriously. #[derive(Clone, Debug, Deserialize, Default, PartialEq)] -pub struct GraphInput { +pub struct GraphData { /// Inputs to the model / computational graph (can be empty vectors if inputs are coming from on-chain). pub input_data: DataSource, + /// Outputs of the model / computational graph (can be empty vectors if outputs are coming from on-chain). + pub output_data: Option, } -impl GraphInput { +impl GraphData { /// pub fn new(input_data: DataSource) -> Self { - GraphInput { input_data } + GraphData { + input_data, + output_data: None, + } } /// Load the model input from a file @@ -432,8 +310,9 @@ impl GraphInput { let mut batched_inputs = vec![]; let iterable = match self { - GraphInput { + GraphData { input_data: DataSource::File(data), + output_data: _, } => data, _ => { todo!("on-chain data batching not implemented yet") @@ -474,79 +353,13 @@ impl GraphInput { // create a new GraphWitness for each batch let batches = input_batches .into_iter() - .map(GraphInput::new) - .collect::>(); + .map(GraphData::new) + .collect::>(); Ok(batches) } } -#[cfg(feature = "python-bindings")] -fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec) { - let poseidon_hash: Vec<[u64; 4]> = poseidon_hash - .iter() - .map(field_to_vecu64_montgomery) - .collect(); - pydict.set_item("poseidon_hash", poseidon_hash).unwrap(); -} - -#[cfg(feature = "python-bindings")] -fn g1affine_to_pydict(g1affine_dict: &PyDict, g1affine: &G1Affine) { - let g1affine_x = field_to_vecu64_montgomery(&g1affine.x); - let g1affine_y = field_to_vecu64_montgomery(&g1affine.y); - g1affine_dict.set_item("x", g1affine_x).unwrap(); - g1affine_dict.set_item("y", g1affine_y).unwrap(); -} - -#[cfg(feature = "python-bindings")] -use super::modules::ElGamalResult; -#[cfg(feature = "python-bindings")] -fn insert_elgamal_results_pydict(py: Python, pydict: &PyDict, elgamal_results: &ElGamalResult) { - let results_dict = PyDict::new(py); - let cipher_text: Vec> = elgamal_results - .ciphertexts - .iter() - .map(|v| { - v.iter() - .map(field_to_vecu64_montgomery) - .collect::>() - }) - .collect::>>(); - results_dict.set_item("ciphertexts", cipher_text).unwrap(); - - let variables_dict = PyDict::new(py); - let variables = &elgamal_results.variables; - - let r = field_to_vecu64_montgomery(&variables.r); - variables_dict.set_item("r", r).unwrap(); - // elgamal secret key - let sk = field_to_vecu64_montgomery(&variables.sk); - variables_dict.set_item("sk", sk).unwrap(); - - let pk_dict = PyDict::new(py); - // elgamal public key - g1affine_to_pydict(pk_dict, &variables.pk); - variables_dict.set_item("pk", pk_dict).unwrap(); - - let aux_generator_dict = PyDict::new(py); - // elgamal aux generator used in ecc chip - g1affine_to_pydict(aux_generator_dict, &variables.aux_generator); - variables_dict - .set_item("aux_generator", aux_generator_dict) - .unwrap(); - - // elgamal window size used in ecc chip - variables_dict - .set_item("window_size", variables.window_size) - .unwrap(); - - results_dict.set_item("variables", variables_dict).unwrap(); - - pydict.set_item("elgamal", results_dict).unwrap(); - - //elgamal -} - #[cfg(feature = "python-bindings")] impl ToPyObject for CallsToAccount { fn to_object(&self, py: Python) -> PyObject { @@ -582,111 +395,14 @@ impl ToPyObject for FileSourceInner { } } -#[cfg(feature = "python-bindings")] -impl ToPyObject for WitnessSource { - fn to_object(&self, py: Python) -> PyObject { - match self { - WitnessSource::File(data) => { - let field_elem: Vec> = data - .iter() - .map(|x| x.iter().map(field_to_vecu64).collect()) - .collect(); - field_elem.to_object(py) - } - WitnessSource::OnChain(source) => { - let dict = PyDict::new(py); - dict.set_item("rpc_url", &source.rpc).unwrap(); - dict.set_item("calls_to_accounts", &source.calls).unwrap(); - dict.to_object(py) - } - } - } -} - -#[cfg(feature = "python-bindings")] -impl ToPyObject for GraphWitness { - fn to_object(&self, py: Python) -> PyObject { - // Create a Python dictionary - let dict = PyDict::new(py); - let dict_inputs = PyDict::new(py); - let dict_params = PyDict::new(py); - let dict_outputs = PyDict::new(py); - - let input_data_mut = &self.input_data; - let output_data_mut = &self.output_data; - - dict.set_item("input_data", &input_data_mut).unwrap(); - dict.set_item("output_data", &output_data_mut).unwrap(); - - if let Some(processed_inputs) = &self.processed_inputs { - //poseidon_hash - if let Some(processed_inputs_poseidon_hash) = &processed_inputs.poseidon_hash { - insert_poseidon_hash_pydict(&dict_inputs, processed_inputs_poseidon_hash); - } - if let Some(processed_inputs_elgamal) = &processed_inputs.elgamal { - insert_elgamal_results_pydict(py, dict_inputs, processed_inputs_elgamal); - } - - dict.set_item("processed_inputs", dict_inputs).unwrap(); - } - - if let Some(processed_params) = &self.processed_params { - if let Some(processed_params_poseidon_hash) = &processed_params.poseidon_hash { - insert_poseidon_hash_pydict(dict_params, processed_params_poseidon_hash); - } - if let Some(processed_params_elgamal) = &processed_params.elgamal { - insert_elgamal_results_pydict(py, dict_params, processed_params_elgamal); - } - - dict.set_item("processed_params", dict_params).unwrap(); - } - - if let Some(processed_outputs) = &self.processed_outputs { - if let Some(processed_outputs_poseidon_hash) = &processed_outputs.poseidon_hash { - insert_poseidon_hash_pydict(dict_outputs, processed_outputs_poseidon_hash); - } - if let Some(processed_outputs_elgamal) = &processed_outputs.elgamal { - insert_elgamal_results_pydict(py, dict_outputs, processed_outputs_elgamal); - } - - dict.set_item("processed_outputs", dict_outputs).unwrap(); - } - - dict.to_object(py) - } -} - -impl Serialize for GraphInput { +impl Serialize for GraphData { fn serialize(&self, serializer: S) -> Result where S: Serializer, { - let mut state = serializer.serialize_struct("GraphInput", 4)?; - state.serialize_field("input_data", &self.input_data)?; - state.end() - } -} - -impl Serialize for GraphWitness { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut state = serializer.serialize_struct("GraphWitness", 4)?; + let mut state = serializer.serialize_struct("GraphData", 4)?; state.serialize_field("input_data", &self.input_data)?; state.serialize_field("output_data", &self.output_data)?; - - if let Some(processed_inputs) = &self.processed_inputs { - state.serialize_field("processed_inputs", &processed_inputs)?; - } - - if let Some(processed_params) = &self.processed_params { - state.serialize_field("processed_params", &processed_params)?; - } - - if let Some(processed_outputs) = &self.processed_outputs { - state.serialize_field("processed_outputs", &processed_outputs)?; - } state.end() } } @@ -716,7 +432,7 @@ mod tests { #[test] // this is for backwards compatibility with the old format fn test_graph_input_serialization_round_trip() { - let file = GraphInput::new(DataSource::from(vec![vec![ + let file = GraphData::new(DataSource::from(vec![vec![ 0.05326242372393608, 0.07497056573629379, 0.05235547572374344, @@ -724,12 +440,11 @@ mod tests { let serialized = serde_json::to_string(&file).unwrap(); - const JSON: &str = - r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]]}"#; + const JSON: &str = r#"{"input_data":[[0.05326242372393608,0.07497056573629379,0.05235547572374344]],"output_data":null}"#; assert_eq!(serialized, JSON); - let graph_input3 = serde_json::from_str::(JSON) + let graph_input3 = serde_json::from_str::(JSON) .map_err(|e| e.to_string()) .unwrap(); assert_eq!(graph_input3, file); diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 30a2ccd40..c384c9c10 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -12,16 +12,21 @@ pub mod utilities; pub mod vars; use halo2_proofs::circuit::Value; -pub use input::{DataSource, GraphWitness, WitnessSource}; +pub use input::DataSource; +use itertools::Itertools; #[cfg(not(target_arch = "wasm32"))] use self::input::OnChainSource; -use self::input::{FileSource, GraphInput, WitnessFileSource}; +use self::input::{FileSource, GraphData}; +use self::modules::{ + GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSettings, ModuleSizes, +}; use crate::circuit::lookup::LookupOp; use crate::circuit::modules::ModulePlanner; use crate::circuit::CheckMode; use crate::commands::RunArgs; use crate::graph::modules::ModuleInstanceOffset; +use crate::pfsys::field_to_vecu64; use crate::tensor::{Tensor, ValTensor}; use halo2_proofs::{ circuit::Layouter, @@ -32,17 +37,20 @@ use halo2curves::ff::PrimeField; use log::{error, info, trace}; pub use model::*; pub use node::*; +#[cfg(feature = "python-bindings")] +use pyo3::prelude::*; +#[cfg(feature = "python-bindings")] +use pyo3::types::PyDict; +#[cfg(feature = "python-bindings")] +use pyo3::ToPyObject; use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::io::{Read, Write}; +use std::ops::Deref; use thiserror::Error; pub use utilities::*; pub use vars::*; -use self::modules::{ - GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSettings, ModuleSizes, -}; - /// circuit related errors. #[derive(Debug, Error)] pub enum GraphError { @@ -90,6 +98,50 @@ pub enum GraphError { PackingExponent, } +/// Inner elements of witness coming from a witness +#[derive(Clone, Debug, Default)] +pub struct WitnessFileSource(Vec>); + +impl From>> for WitnessFileSource { + fn from(value: Vec>) -> Self { + WitnessFileSource(value) + } +} + +// !!! ALWAYS USE JSON SERIALIZATION FOR GRAPH INPUT +// UNTAGGED ENUMS WONT WORK :( as highlighted here: +impl<'de> Deserialize<'de> for WitnessFileSource { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let this_json: Box = Deserialize::deserialize(deserializer)?; + + let t: Vec> = serde_json::from_str(this_json.get()) + .map_err(|_| serde::de::Error::custom("failed to deserialize WitnessSource"))?; + + let t: Vec> = t + .iter() + .map(|x| x.iter().map(|fp| Fp::from_raw(*fp)).collect()) + .collect(); + Ok(WitnessFileSource(t)) + } +} + +impl Serialize for WitnessFileSource { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let field_elems: Vec> = self + .0 + .iter() + .map(|x| x.iter().map(field_to_vecu64).collect()) + .collect::>(); + field_elems.serialize(serializer) + } +} + const ASSUMED_BLINDING_FACTORS: usize = 6; /// 26 @@ -97,11 +149,11 @@ const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2; /// Result from a forward pass #[derive(Clone, Debug, Default, Serialize, Deserialize)] -pub struct ForwardResult { +pub struct GraphWitness { /// The inputs of the forward pass - pub inputs: Vec>, + pub inputs: WitnessFileSource, /// The output of the forward pass - pub outputs: Vec>, + pub outputs: WitnessFileSource, /// Any hashes of inputs generated during the forward pass pub processed_inputs: Option, /// Any hashes of params generated during the forward pass @@ -109,7 +161,179 @@ pub struct ForwardResult { /// Any hashes of outputs generated during the forward pass pub processed_outputs: Option, /// max lookup input - pub max_lookup_input: i128, + pub max_lookup_inputs: i128, +} + +impl GraphWitness { + /// + pub fn new(inputs: Vec>, outputs: Vec>) -> Self { + GraphWitness { + inputs: inputs.into(), + outputs: outputs.into(), + processed_inputs: None, + processed_params: None, + processed_outputs: None, + max_lookup_inputs: 0, + } + } + /// Load the model input from a file + pub fn from_path(path: std::path::PathBuf) -> Result> { + let mut file = std::fs::File::open(path)?; + let mut data = String::new(); + file.read_to_string(&mut data)?; + serde_json::from_str(&data).map_err(|e| e.into()) + } + + /// Save the model input to a file + pub fn save(&self, path: std::path::PathBuf) -> Result<(), Box> { + serde_json::to_writer(std::fs::File::create(path)?, &self).map_err(|e| e.into()) + } + + /// + pub fn get_input_tensor(&self) -> Vec> { + self.inputs + .0 + .clone() + .into_iter() + .map(|i| Tensor::from(i.into_iter())) + .collect::>>() + } +} + +#[cfg(feature = "python-bindings")] +impl ToPyObject for GraphWitness { + fn to_object(&self, py: Python) -> PyObject { + // Create a Python dictionary + let dict = PyDict::new(py); + let dict_inputs = PyDict::new(py); + let dict_params = PyDict::new(py); + let dict_outputs = PyDict::new(py); + + let inputs: Vec> = self + .inputs + .0 + .iter() + .map(|x| x.iter().map(field_to_vecu64).collect()) + .collect(); + + let outputs: Vec> = self + .outputs + .0 + .iter() + .map(|x| x.iter().map(field_to_vecu64).collect()) + .collect(); + + dict.set_item("inputs", &inputs).unwrap(); + dict.set_item("outputs", &outputs).unwrap(); + dict.set_item("max_lookup_inputs", &self.max_lookup_inputs) + .unwrap(); + + if let Some(processed_inputs) = &self.processed_inputs { + //poseidon_hash + if let Some(processed_inputs_poseidon_hash) = &processed_inputs.poseidon_hash { + insert_poseidon_hash_pydict(&dict_inputs, processed_inputs_poseidon_hash); + } + if let Some(processed_inputs_elgamal) = &processed_inputs.elgamal { + insert_elgamal_results_pydict(py, dict_inputs, processed_inputs_elgamal); + } + + dict.set_item("processed_inputs", dict_inputs).unwrap(); + } + + if let Some(processed_params) = &self.processed_params { + if let Some(processed_params_poseidon_hash) = &processed_params.poseidon_hash { + insert_poseidon_hash_pydict(dict_params, processed_params_poseidon_hash); + } + if let Some(processed_params_elgamal) = &processed_params.elgamal { + insert_elgamal_results_pydict(py, dict_params, processed_params_elgamal); + } + + dict.set_item("processed_params", dict_params).unwrap(); + } + + if let Some(processed_outputs) = &self.processed_outputs { + if let Some(processed_outputs_poseidon_hash) = &processed_outputs.poseidon_hash { + insert_poseidon_hash_pydict(dict_outputs, processed_outputs_poseidon_hash); + } + if let Some(processed_outputs_elgamal) = &processed_outputs.elgamal { + insert_elgamal_results_pydict(py, dict_outputs, processed_outputs_elgamal); + } + + dict.set_item("processed_outputs", dict_outputs).unwrap(); + } + + dict.to_object(py) + } +} + +#[cfg(feature = "python-bindings")] +fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec) { + let poseidon_hash: Vec<[u64; 4]> = poseidon_hash + .iter() + .map(field_to_vecu64_montgomery) + .collect(); + pydict.set_item("poseidon_hash", poseidon_hash).unwrap(); +} + +#[cfg(feature = "python-bindings")] +use crate::pfsys::field_to_vecu64_montgomery; +#[cfg(feature = "python-bindings")] +use halo2curves::bn256::G1Affine; +#[cfg(feature = "python-bindings")] +fn g1affine_to_pydict(g1affine_dict: &PyDict, g1affine: &G1Affine) { + let g1affine_x = field_to_vecu64_montgomery(&g1affine.x); + let g1affine_y = field_to_vecu64_montgomery(&g1affine.y); + g1affine_dict.set_item("x", g1affine_x).unwrap(); + g1affine_dict.set_item("y", g1affine_y).unwrap(); +} + +#[cfg(feature = "python-bindings")] +use modules::ElGamalResult; +#[cfg(feature = "python-bindings")] +fn insert_elgamal_results_pydict(py: Python, pydict: &PyDict, elgamal_results: &ElGamalResult) { + let results_dict = PyDict::new(py); + let cipher_text: Vec> = elgamal_results + .ciphertexts + .iter() + .map(|v| { + v.iter() + .map(field_to_vecu64_montgomery) + .collect::>() + }) + .collect::>>(); + results_dict.set_item("ciphertexts", cipher_text).unwrap(); + + let variables_dict = PyDict::new(py); + let variables = &elgamal_results.variables; + + let r = field_to_vecu64_montgomery(&variables.r); + variables_dict.set_item("r", r).unwrap(); + // elgamal secret key + let sk = field_to_vecu64_montgomery(&variables.sk); + variables_dict.set_item("sk", sk).unwrap(); + + let pk_dict = PyDict::new(py); + // elgamal public key + g1affine_to_pydict(pk_dict, &variables.pk); + variables_dict.set_item("pk", pk_dict).unwrap(); + + let aux_generator_dict = PyDict::new(py); + // elgamal aux generator used in ecc chip + g1affine_to_pydict(aux_generator_dict, &variables.aux_generator); + variables_dict + .set_item("aux_generator", aux_generator_dict) + .unwrap(); + + // elgamal window size used in ecc chip + variables_dict + .set_item("window_size", variables.window_size) + .unwrap(); + + results_dict.set_item("variables", variables_dict).unwrap(); + + pydict.set_item("elgamal", results_dict).unwrap(); + + //elgamal } /// model parameters @@ -173,9 +397,7 @@ pub struct GraphCircuit { /// The model / graph of computations. pub model: Model, /// Vector of input tensors to the model / graph of computations. - pub inputs: Vec>, - /// Vector of input tensors to the model / graph of computations. - pub outputs: Vec>, + pub graph_witness: GraphWitness, /// The settings of the model / graph of computations. pub settings: GraphSettings, /// The settings of the model's modules. @@ -229,14 +451,14 @@ impl GraphCircuit { run_args: RunArgs, check_mode: CheckMode, ) -> Result> { - // placeholder dummy inputs - must call prepare_public_inputs to load data afterwards - let mut inputs: Vec> = vec![]; + // // placeholder dummy inputs - must call prepare_public_inputs to load data afterwards + let mut inputs: Vec> = vec![]; for shape in model.graph.input_shapes() { - let t: Tensor = Tensor::new(None, &shape).unwrap(); + let t: Vec = vec![Fp::zero(); shape.iter().product::()]; inputs.push(t); } - // dummy module settings, must load from GraphInput after + // dummy module settings, must load from GraphData after let module_settings = ModuleSettings::default(); let mut settings = model.gen_params(run_args, check_mode)?; @@ -263,8 +485,7 @@ impl GraphCircuit { Ok(GraphCircuit { model, - inputs, - outputs: vec![], + graph_witness: GraphWitness::new(inputs, vec![]), settings, module_settings, }) @@ -277,78 +498,37 @@ impl GraphCircuit { check_mode: CheckMode, ) -> Result> { // placeholder dummy inputs - must call prepare_public_inputs to load data afterwards - let mut inputs: Vec> = vec![]; + let mut inputs: Vec> = vec![]; for shape in model.graph.input_shapes() { - let t: Tensor = Tensor::new(None, &shape).unwrap(); + let t: Vec = vec![Fp::zero(); shape.iter().product::()]; inputs.push(t); } - // dummy module settings, must load from GraphInput after + // dummy module settings, must load from GraphData after let module_settings = ModuleSettings::default(); settings.check_mode = check_mode; Ok(GraphCircuit { model, - inputs, - outputs: vec![], + graph_witness: GraphWitness::new(inputs, vec![]), settings, module_settings, }) } - #[cfg(target_arch = "wasm32")] /// load inputs and outputs for the model pub fn load_graph_witness( &mut self, data: &GraphWitness, ) -> Result<(), Box> { - self.inputs = - self.process_witness_source(&data.input_data, self.model.graph.input_shapes())?; - self.outputs = - self.process_witness_source(&data.output_data, self.model.graph.output_shapes())?; + self.graph_witness = data.clone(); // load the module settings self.module_settings = ModuleSettings::from(data); Ok(()) } - #[cfg(not(target_arch = "wasm32"))] - /// load inputs and outputs for the model - pub async fn load_graph_witness( - &mut self, - data: &GraphWitness, - test_on_chain_data: Option, - ) -> Result<(), Box> { - let mut data = data.clone(); - - // mutate it if need be - if let Some(test_path) = test_on_chain_data { - self.populate_on_chain_test_data(&mut data, test_path) - .await?; - } else { - self.inputs = self - .process_witness_source( - &data.input_data, - self.model.graph.input_shapes(), - self.model.graph.get_input_scales(), - ) - .await?; - self.outputs = self - .process_witness_source( - &data.output_data, - self.model.graph.output_shapes(), - self.model.graph.get_output_scales(), - ) - .await?; - } - - // load the module settings - self.module_settings = ModuleSettings::from(&data); - - Ok(()) - } - /// Prepare the public inputs for the circuit. pub fn prepare_public_inputs( &mut self, @@ -359,10 +539,10 @@ impl GraphCircuit { // as they are configured in that order as Column let mut public_inputs = vec![]; if self.settings.run_args.input_visibility.is_public() { - public_inputs = self.inputs.clone(); + public_inputs = self.graph_witness.inputs.0.clone(); } if self.settings.run_args.output_visibility.is_public() { - public_inputs.extend(self.outputs.clone()); + public_inputs.extend(self.graph_witness.outputs.0.clone()); } info!( "public inputs lengths: {:?}", @@ -392,27 +572,23 @@ impl GraphCircuit { #[cfg(target_arch = "wasm32")] pub fn load_graph_input( &mut self, - data: &GraphInput, - ) -> Result<(), Box> { + data: &GraphData, + ) -> Result>, Box> { let shapes = self.model.graph.input_shapes(); let scales = vec![self.settings.run_args.scale; shapes.len()]; - self.inputs = self.process_data_source(&data.input_data, shapes, scales)?; - Ok(()) + self.process_data_source(&data.input_data, shapes, scales) } /// #[cfg(not(target_arch = "wasm32"))] pub async fn load_graph_input( &mut self, - data: &GraphInput, - ) -> Result<(), Box> { + data: &GraphData, + ) -> Result>, Box> { let shapes = self.model.graph.input_shapes(); let scales = vec![self.settings.run_args.scale; shapes.len()]; - self.inputs = self - .process_data_source(&data.input_data, shapes, scales) - .await?; - - Ok(()) + self.process_data_source(&data.input_data, shapes, scales) + .await } #[cfg(target_arch = "wasm32")] @@ -431,21 +607,6 @@ impl GraphCircuit { } } - #[cfg(target_arch = "wasm32")] - /// Process the data source for the model - fn process_witness_source( - &mut self, - data: &WitnessSource, - shapes: Vec>, - ) -> Result>, Box> { - match &data { - WitnessSource::OnChain(_) => { - panic!("Cannot use on-chain data source as input for wasm rn.") - } - WitnessSource::File(file_data) => self.load_witness_file_data(file_data, &shapes), - } - } - #[cfg(not(target_arch = "wasm32"))] /// Process the data source for the model async fn process_data_source( @@ -467,27 +628,6 @@ impl GraphCircuit { } } - #[cfg(not(target_arch = "wasm32"))] - /// Process the data source for the model - async fn process_witness_source( - &mut self, - data: &WitnessSource, - shapes: Vec>, - scales: Vec, - ) -> Result>, Box> { - match &data { - WitnessSource::OnChain(source) => { - let mut per_item_scale = vec![]; - for (i, shape) in shapes.iter().enumerate() { - per_item_scale.extend(vec![scales[i]; shape.iter().product::()]); - } - self.load_on_chain_data(source.clone(), &shapes, per_item_scale) - .await - } - WitnessSource::File(file_data) => self.load_witness_file_data(file_data, &shapes), - } - } - /// Prepare on chain test data #[cfg(not(target_arch = "wasm32"))] pub async fn load_on_chain_data( @@ -545,7 +685,7 @@ impl GraphCircuit { ) -> Result>, Box> { // quantize the supplied data using the provided scale. let mut data: Vec> = vec![]; - for (d, shape) in file_data.iter().zip(shapes) { + for (d, shape) in file_data.0.iter().zip(shapes) { let mut t: Tensor = d.clone().into_iter().into(); t.reshape(shape); data.push(t); @@ -554,22 +694,23 @@ impl GraphCircuit { } /// Calibrate the circuit to the supplied data. - pub fn calibrate(&mut self) -> Result<(), Box> { - let res = self.forward()?; + pub fn calibrate(&mut self, input: &[Tensor]) -> Result<(), Box> { + let res = self.forward(input)?; + let max_range = 2i128.pow(self.settings.run_args.bits as u32 - 1); - if res.max_lookup_input > max_range { - let recommended_bits = (res.max_lookup_input as f64).log2().ceil() as usize + 1; + if res.max_lookup_inputs > max_range { + let recommended_bits = (res.max_lookup_inputs as f64).log2().ceil() as usize + 1; if recommended_bits <= (MAX_PUBLIC_SRS - 1) as usize { self.settings.run_args.bits = recommended_bits; self.settings.run_args.logrows = (recommended_bits + 1) as u32; - return self.calibrate(); + return self.calibrate(input); } else { let err_string = format!("No possible value of bits (estimate {}) at scale {} can accomodate this value.", recommended_bits, self.settings.run_args.scale); return Err(err_string.into()); } } else { - let min_bits = (res.max_lookup_input as f64).log2().ceil() as usize + 1; + let min_bits = (res.max_lookup_inputs as f64).log2().ceil() as usize + 1; let min_rows_from_constraints = (self.settings.num_constraints as f64 + ASSUMED_BLINDING_FACTORS as f64) @@ -627,14 +768,17 @@ impl GraphCircuit { } /// Runs the forward pass of the model / graph of computations and any associated hashing. - pub fn forward(&self) -> Result> { + pub fn forward( + &self, + inputs: &[Tensor], + ) -> Result> { let visibility = VarVisibility::from_args(self.settings.run_args)?; let mut processed_inputs = None; let mut processed_params = None; let mut processed_outputs = None; if visibility.input.requires_processing() { - processed_inputs = Some(GraphModules::forward(&self.inputs, visibility.input)?); + processed_inputs = Some(GraphModules::forward(inputs, visibility.input)?); } if visibility.params.requires_processing() { @@ -649,19 +793,32 @@ impl GraphCircuit { )?); } - let outputs = self.model.forward(&self.inputs)?; + let model_results = self.model.forward(inputs)?; if visibility.output.requires_processing() { - processed_outputs = Some(GraphModules::forward(&outputs.outputs, visibility.output)?); + processed_outputs = Some(GraphModules::forward( + &model_results.outputs, + visibility.output, + )?); } - Ok(ForwardResult { - inputs: self.inputs.clone(), - outputs: outputs.outputs, + Ok(GraphWitness { + inputs: inputs + .to_vec() + .iter() + .map(|t| t.deref().to_vec()) + .collect_vec() + .into(), + outputs: model_results + .outputs + .iter() + .map(|t| t.deref().to_vec()) + .collect_vec() + .into(), processed_inputs, processed_params, processed_outputs, - max_lookup_input: outputs.max_lookup_inputs, + max_lookup_inputs: model_results.max_lookup_inputs, }) } @@ -687,12 +844,13 @@ impl GraphCircuit { /// #[cfg(not(target_arch = "wasm32"))] - async fn populate_on_chain_test_data( + pub async fn populate_on_chain_test_data( &mut self, - data: &mut GraphWitness, + data: &mut GraphData, test_on_chain_data: TestOnChainData, ) -> Result<(), Box> { // Set up local anvil instance for reading on-chain data + if matches!( test_on_chain_data.data_sources.input, TestDataSource::OnChain @@ -703,8 +861,8 @@ impl GraphCircuit { } let input_data = match &data.input_data { - WitnessSource::File(input_data) => input_data, - WitnessSource::OnChain(_) => panic!( + DataSource::File(input_data) => input_data, + DataSource::OnChain(_) => panic!( "Cannot use on-chain data source as input for on-chain test. Will manually populate on-chain data from file source instead" ), @@ -719,16 +877,7 @@ impl GraphCircuit { test_on_chain_data.rpc.as_deref(), ) .await?; - self.inputs = datam.0; data.input_data = datam.1.into(); - } else { - self.inputs = self - .process_witness_source( - &data.input_data, - self.model.graph.input_shapes(), - self.model.graph.get_input_scales(), - ) - .await?; } if matches!( test_on_chain_data.data_sources.output, @@ -740,11 +889,12 @@ impl GraphCircuit { } let output_data = match &data.output_data { - WitnessSource::File(output_data) => output_data, - WitnessSource::OnChain(_) => panic!( + Some(DataSource::File(output_data)) => output_data, + Some(DataSource::OnChain(_)) => panic!( "Cannot use on-chain data source as output for on-chain test. Will manually populate on-chain data from file source instead" ), + _ => panic!("No output data to populate"), }; let datum: (Vec>, OnChainSource) = OnChainSource::test_from_file_data( output_data, @@ -753,18 +903,9 @@ impl GraphCircuit { test_on_chain_data.rpc.as_deref(), ) .await?; - self.outputs = datum.0; - data.output_data = datum.1.into(); - } else { - self.outputs = self - .process_witness_source( - &data.input_data, - self.model.graph.input_shapes(), - self.model.graph.get_output_scales(), - ) - .await?; + data.output_data = Some(datum.1.into()); } - // Save the updated GraphInput struct to the data_path + // Save the updated GraphData struct to the data_path data.save(test_on_chain_data.data)?; Ok(()) } @@ -831,9 +972,10 @@ impl Circuit for GraphCircuit { ) -> Result<(), PlonkError> { trace!("Setting input in synthesize"); let mut inputs = self - .inputs + .graph_witness + .get_input_tensor() .iter() - .map(|i| ValTensor::from(i.map(|x| Value::known(x)))) + .map(|i| ValTensor::from(i.map(Value::known))) .collect::>>(); let mut instance_offset = ModuleInstanceOffset::new(); diff --git a/src/graph/model.rs b/src/graph/model.rs index b184dfabc..691c82df3 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -605,6 +605,17 @@ impl Model { let start_time = instant::Instant::now(); + // reshape inputs + let inputs: Vec> = inputs + .iter() + .zip(self.graph.input_shapes().iter()) + .map(|(input, shape)| { + let mut t = input.clone(); + t.reshape(shape).unwrap(); + t + }) + .collect_vec(); + let mut results = BTreeMap::>::new(); for (i, input_idx) in self.graph.inputs.iter().enumerate() { diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index cbaff50c7..d55bc0ab4 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -3,7 +3,7 @@ mod native_tests { use core::panic; - use ezkl_lib::graph::input::{FileSource, GraphInput}; + use ezkl_lib::graph::input::{FileSource, GraphData}; use ezkl_lib::graph::DataSource; use lazy_static::lazy_static; use std::env::var; @@ -12,6 +12,7 @@ mod native_tests { use tempdir::TempDir; static COMPILE: Once = Once::new(); static KZG19: Once = Once::new(); + static KZG17: Once = Once::new(); static KZG23: Once = Once::new(); static KZG26: Once = Once::new(); static START_ANVIL: Once = Once::new(); @@ -58,6 +59,20 @@ mod native_tests { }); } + fn init_params_17() { + KZG17.call_once(|| { + let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) + .args([ + "gen-srs", + &format!("--srs-path={}/kzg17.srs", TEST_DIR.path().to_str().unwrap()), + "--logrows=17", + ]) + .status() + .expect("failed to execute process"); + assert!(status.success()); + }); + } + fn init_params_23() { KZG23.call_once(|| { let status = Command::new("curl") @@ -126,7 +141,7 @@ mod native_tests { assert!(status.success()); - let data = GraphInput::from_path(format!("{}/{}/input.json", test_dir, test).into()) + let data = GraphData::from_path(format!("{}/{}/input.json", test_dir, test).into()) .expect("failed to load input data"); let input_data = match data.input_data { @@ -139,7 +154,7 @@ mod native_tests { .map(|data| (0..num_batches).flat_map(|_| data.clone()).collect()) .collect(); - let duplicated_data = GraphInput::new(DataSource::File(duplicated_input_data)); + let duplicated_data = GraphData::new(DataSource::File(duplicated_input_data)); let res = duplicated_data.save(format!("{}/{}/input.json", test_dir, output_dir).into()); @@ -424,17 +439,17 @@ mod native_tests { #(#[test_case(TESTS[N])])* fn kzg_prove_and_verify_(test: &str) { crate::native_tests::init_binary(); - crate::native_tests::init_params_19(); + crate::native_tests::init_params_17(); crate::native_tests::mv_test_(test); - kzg_prove_and_verify(test.to_string(), 19, "safe", "private", "private", "public"); + kzg_prove_and_verify(test.to_string(), 17, "safe", "private", "private", "public"); } #(#[test_case(TESTS[N])])* fn kzg_prove_and_verify_hashed_output(test: &str) { crate::native_tests::init_binary(); - crate::native_tests::init_params_19(); + crate::native_tests::init_params_17(); crate::native_tests::mv_test_(test); - kzg_prove_and_verify(test.to_string(), 19, "safe", "private", "private", "hashed"); + kzg_prove_and_verify(test.to_string(), 17, "safe", "private", "private", "hashed"); } #(#[test_case(TESTS[N])])* @@ -517,10 +532,10 @@ mod native_tests { #(#[test_case(TESTS_ON_CHAIN_INPUT[N])])* fn kzg_evm_on_chain_input_prove_and_verify_(test: &str) { crate::native_tests::init_binary(); - crate::native_tests::init_params_19(); + crate::native_tests::init_params_17(); crate::native_tests::mv_test_(test); crate::native_tests::start_anvil(); - kzg_evm_on_chain_input_prove_and_verify(test.to_string(), 200, "on-chain", "file"); + kzg_evm_on_chain_input_prove_and_verify(test.to_string(), 200, "on-chain", "file", 17); } }); @@ -528,10 +543,10 @@ mod native_tests { #(#[test_case(TESTS_ON_CHAIN_INPUT[N])])* fn kzg_evm_on_chain_output_prove_and_verify_(test: &str) { crate::native_tests::init_binary(); - crate::native_tests::init_params_19(); + crate::native_tests::init_params_17(); crate::native_tests::mv_test_(test); crate::native_tests::start_anvil(); - kzg_evm_on_chain_input_prove_and_verify(test.to_string(), 200, "file", "on-chain"); + kzg_evm_on_chain_input_prove_and_verify(test.to_string(), 200, "file", "on-chain", 17); } }); @@ -540,10 +555,10 @@ mod native_tests { #(#[test_case(TESTS_ON_CHAIN_INPUT[N])])* fn kzg_evm_on_chain_input_output_prove_and_verify_(test: &str) { crate::native_tests::init_binary(); - crate::native_tests::init_params_19(); + crate::native_tests::init_params_17(); crate::native_tests::mv_test_(test); crate::native_tests::start_anvil(); - kzg_evm_on_chain_input_prove_and_verify(test.to_string(), 200, "on-chain", "on-chain"); + kzg_evm_on_chain_input_prove_and_verify(test.to_string(), 200, "on-chain", "on-chain", 17); } }); @@ -553,37 +568,37 @@ mod native_tests { #(#[test_case(TESTS_EVM[N])])* fn kzg_evm_prove_and_verify_(test: &str) { crate::native_tests::init_binary(); - crate::native_tests::init_params_19(); + crate::native_tests::init_params_17(); crate::native_tests::mv_test_(test); crate::native_tests::start_anvil(); - kzg_evm_prove_and_verify(test.to_string(), "private", "private", "public", 1); + kzg_evm_prove_and_verify(test.to_string(), "private", "private", "public", 1, 17); } #(#[test_case(TESTS_EVM[N])])* fn kzg_evm_hashed_input_prove_and_verify_(test: &str) { crate::native_tests::init_binary(); - crate::native_tests::init_params_19(); + crate::native_tests::init_params_17(); crate::native_tests::mv_test_(test); crate::native_tests::start_anvil(); - kzg_evm_prove_and_verify(test.to_string(), "hashed", "private", "private", 1); + kzg_evm_prove_and_verify(test.to_string(), "hashed", "private", "private", 1, 17); } #(#[test_case(TESTS_EVM[N])])* fn kzg_evm_hashed_params_prove_and_verify_(test: &str) { crate::native_tests::init_binary(); - crate::native_tests::init_params_19(); + crate::native_tests::init_params_17(); crate::native_tests::mv_test_(test); crate::native_tests::start_anvil(); - kzg_evm_prove_and_verify(test.to_string(), "private", "hashed", "public", 1); + kzg_evm_prove_and_verify(test.to_string(), "private", "hashed", "public", 1, 17); } #(#[test_case(TESTS_EVM[N])])* fn kzg_evm_hashed_output_prove_and_verify_(test: &str) { crate::native_tests::init_binary(); - crate::native_tests::init_params_19(); + crate::native_tests::init_params_17(); crate::native_tests::mv_test_(test); crate::native_tests::start_anvil(); - kzg_evm_prove_and_verify(test.to_string(), "private", "private", "hashed", 1); + kzg_evm_prove_and_verify(test.to_string(), "private", "private", "hashed", 1, 17); } @@ -1345,6 +1360,7 @@ mod native_tests { param_visibility: &str, output_visibility: &str, num_runs: usize, + logrows: usize, ) { let test_dir = TEST_DIR.path().to_str().unwrap(); let anvil_url = ANVIL_URL.as_str(); @@ -1409,7 +1425,7 @@ mod native_tests { &format!("{}/{}/key.pk", test_dir, example_name), "--vk-path", &format!("{}/{}/key.vk", test_dir, example_name), - &format!("--srs-path={}/kzg19.srs", test_dir), + &format!("--srs-path={}/kzg{}.srs", test_dir, logrows), &format!( "--settings-path={}/{}/settings.json", test_dir, example_name @@ -1430,7 +1446,7 @@ mod native_tests { &format!("{}/{}/proof.pf", test_dir, example_name), "--pk-path", &format!("{}/{}/key.pk", test_dir, example_name), - &format!("--srs-path={}/kzg19.srs", TEST_DIR.path().to_str().unwrap()), + &format!("--srs-path={}/kzg{}.srs", test_dir, logrows), "--transcript=evm", "--strategy=single", &format!( @@ -1448,7 +1464,7 @@ mod native_tests { ); let code_arg = format!("{}/{}/deployment.code", test_dir, example_name); let vk_arg = format!("{}/{}/key.vk", test_dir, example_name); - let param_arg = format!("--srs-path={}/kzg19.srs", test_dir); + let param_arg = format!("--srs-path={}/kzg{}.srs", test_dir, logrows); let opt_arg = format!("--optimizer-runs={}", num_runs); let rpc_arg = format!("--rpc-url={}", anvil_url); let addr_path_arg = format!("--addr-path={}/{}/addr.txt", test_dir, example_name); @@ -1531,6 +1547,7 @@ mod native_tests { num_runs: usize, input_source: &str, output_source: &str, + logrows: usize, ) { let test_dir = TEST_DIR.path().to_str().unwrap(); @@ -1559,16 +1576,27 @@ mod native_tests { .expect("failed to execute process"); assert!(status.success()); + let data_path = format!("{}/{}/input.json", test_dir, example_name); + let witness_path = format!("{}/{}/witness.json", test_dir, example_name); + let test_on_chain_data_path = format!("{}/{}/on_chain_input.json", test_dir, example_name); + let rpc_arg = format!("--rpc-url={}", ANVIL_URL.as_str()); + + let test_input_source = format!("--input-source={}", input_source); + let test_output_source = format!("--output-source={}", output_source); + let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "gen-witness", + "setup-test-evm-data", "-D", - format!("{}/{}/input.json", test_dir, example_name).as_str(), + data_path.as_str(), "-M", model_path.as_str(), circuit_settings.as_str(), - "-O", - format!("{}/{}/input.json", test_dir, example_name).as_str(), + "--test-data", + test_on_chain_data_path.as_str(), + rpc_arg.as_str(), + test_input_source.as_str(), + test_output_source.as_str(), ]) .status() .expect("failed to execute process"); @@ -1576,40 +1604,30 @@ mod native_tests { let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "setup", + "gen-witness", + "-D", + test_on_chain_data_path.as_str(), "-M", model_path.as_str(), - "--pk-path", - &format!("{}/{}/key.pk", test_dir, example_name), - "--vk-path", - &format!("{}/{}/key.vk", test_dir, example_name), - &format!("--srs-path={}/kzg19.srs", test_dir), circuit_settings.as_str(), + "-O", + witness_path.as_str(), ]) .status() .expect("failed to execute process"); assert!(status.success()); - let data_path = format!("{}/{}/input.json", test_dir, example_name); - let test_on_chain_data_path = format!("{}/{}/on_chain_input.json", test_dir, example_name); - let rpc_arg = format!("--rpc-url={}", ANVIL_URL.as_str()); - - let test_input_source = format!("--input-source={}", input_source); - let test_output_source = format!("--output-source={}", output_source); - let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ - "setup-test-evm-witness", - "-W", - data_path.as_str(), + "setup", "-M", model_path.as_str(), + "--pk-path", + &format!("{}/{}/key.pk", test_dir, example_name), + "--vk-path", + &format!("{}/{}/key.vk", test_dir, example_name), + &format!("--srs-path={}/kzg{}.srs", test_dir, logrows), circuit_settings.as_str(), - "--test-witness", - test_on_chain_data_path.as_str(), - rpc_arg.as_str(), - test_input_source.as_str(), - test_output_source.as_str(), ]) .status() .expect("failed to execute process"); @@ -1619,14 +1637,14 @@ mod native_tests { .args([ "prove", "-W", - test_on_chain_data_path.as_str(), + witness_path.as_str(), "-M", model_path.as_str(), "--proof-path", &format!("{}/{}/proof.pf", test_dir, example_name), "--pk-path", &format!("{}/{}/key.pk", test_dir, example_name), - &format!("--srs-path={}/kzg19.srs", TEST_DIR.path().to_str().unwrap()), + &format!("--srs-path={}/kzg{}.srs", test_dir, logrows), "--transcript=evm", "--strategy=single", circuit_settings.as_str(), @@ -1640,7 +1658,7 @@ mod native_tests { test_dir, example_name ); let vk_arg = format!("{}/{}/key.vk", test_dir, example_name); - let param_arg = format!("--srs-path={}/kzg19.srs", test_dir); + let param_arg = format!("--srs-path={}/kzg{}.srs", test_dir, logrows); let opt_arg = format!("--optimizer-runs={}", num_runs); @@ -1658,7 +1676,7 @@ mod native_tests { param_arg.as_str(), "--vk-path", vk_arg.as_str(), - "-W", + "-D", test_on_chain_data_path.as_str(), opt_arg.as_str(), ]) @@ -1671,7 +1689,7 @@ mod native_tests { .args([ "deploy-evm-da-verifier", circuit_settings.as_str(), - "-W", + "-D", test_on_chain_data_path.as_str(), "--sol-code-path", sol_arg.as_str(), diff --git a/tests/python/binding_tests.py b/tests/python/binding_tests.py index ba8f43542..5328071b8 100644 --- a/tests/python/binding_tests.py +++ b/tests/python/binding_tests.py @@ -153,8 +153,8 @@ def test_forward(): with open(output_path, "r") as f: data = json.load(f) - assert data["input_data"] == res["input_data"] - assert data["output_data"] == res["output_data"] + assert data["inputs"] == res["inputs"] + assert data["outputs"] == res["outputs"] assert data["processed_inputs"]["poseidon_hash"] == res["processed_inputs"]["poseidon_hash"] == [[ 8270957937025516140, 11801026918842104328, 2203849898884507041, 140307258138425306]] diff --git a/tests/wasm/test.witness.json b/tests/wasm/test.witness.json index 0173c9646..7ad73b1b4 100644 --- a/tests/wasm/test.witness.json +++ b/tests/wasm/test.witness.json @@ -1,5 +1,5 @@ { - "input_data": [ + "inputs": [ [ [ 2, @@ -21,7 +21,7 @@ ] ] ], - "output_data": [ + "outputs": [ [ [ 0, @@ -48,5 +48,6 @@ 0 ] ] - ] + ], + "max_lookup_inputs": 0 } \ No newline at end of file