diff --git a/abis/DataAttestation.json b/abis/DataAttestation.json index 729cb3266..8a90dcd03 100644 --- a/abis/DataAttestation.json +++ b/abis/DataAttestation.json @@ -18,7 +18,7 @@ }, { "internalType": "uint256[]", - "name": "_outputScales", + "name": "_scales", "type": "uint256[]" }, { @@ -35,19 +35,6 @@ "stateMutability": "nonpayable", "type": "constructor" }, - { - "inputs": [], - "name": "INPUT_SCALE", - "outputs": [ - { - "internalType": "uint256", - "name": "", - "type": "uint256" - } - ], - "stateMutability": "view", - "type": "function" - }, { "inputs": [ { @@ -106,7 +93,7 @@ "type": "uint256" } ], - "name": "outputScales", + "name": "scales", "outputs": [ { "internalType": "uint256", diff --git a/contracts/AttestData.sol b/contracts/AttestData.sol index 106703e3d..4dc273133 100644 --- a/contracts/AttestData.sol +++ b/contracts/AttestData.sol @@ -28,8 +28,7 @@ contract DataAttestation { } AccountCall[] public accountCalls; - uint public constant INPUT_SCALE = 1 << 0; - uint[] public outputScales; + uint[] public scales; address public admin; @@ -55,13 +54,13 @@ contract DataAttestation { address[] memory _contractAddresses, bytes[][] memory _callData, uint256[][] memory _decimals, - uint[] memory _outputScales, + uint[] memory _scales, uint8 _instanceOffset, address _admin ) { admin = _admin; - for (uint i; i < _outputScales.length; i++) { - outputScales.push(1 << _outputScales[i]); + for (uint i; i < _scales.length; i++) { + scales.push(1 << _scales[i]); } populateAccountCalls(_contractAddresses, _callData, _decimals); instanceOffset = _instanceOffset; @@ -239,10 +238,7 @@ contract DataAttestation { account, accountCalls[i].callData[j] ); - uint256 scale = INPUT_SCALE; - if (counter >= INPUT_CALLS) { - scale = outputScales[counter - INPUT_CALLS]; - } + uint256 scale = scales[counter]; int256 quantized_data = quantizeData( returnData, accountCalls[i].decimals[j], diff --git a/examples/onnx/mnist_gan/settings.json b/examples/onnx/mnist_gan/settings.json index 59c2afdbe..cf13ad789 100644 --- a/examples/onnx/mnist_gan/settings.json +++ b/examples/onnx/mnist_gan/settings.json @@ -1 +1 @@ -{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":7,"param_scale":7,"scale_rebase_multiplier":10,"bits":16,"logrows":17,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Public"},"num_constraints":8928903,"total_const_size":8753605,"model_instance_shapes":[[1,28,28]],"model_output_scales":[42],"module_sizes":{"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Sigmoid":{"scale":4398046500000.0}},{"Exp":{"scale":2097152.0}},{"Exp":{"scale":34359740000.0}},{"GreaterThan":{"a":0.0}}],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null} \ No newline at end of file +{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":7,"param_scale":7,"scale_rebase_multiplier":10,"bits":16,"logrows":17,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Public"},"num_constraints":8928903,"total_const_size":8753605,"model_instance_shapes":[[1,28,28]],"model_output_scales":[42],"model_input_scales":[7],"module_sizes":{"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Sigmoid":{"scale":4398046500000.0}},{"Exp":{"scale":2097152.0}},{"Exp":{"scale":34359740000.0}},{"GreaterThan":{"a":0.0}}],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null} \ No newline at end of file diff --git a/examples/onnx/variable_cnn/settings.json b/examples/onnx/variable_cnn/settings.json index 97e67f66f..2248a8fc7 100644 --- a/examples/onnx/variable_cnn/settings.json +++ b/examples/onnx/variable_cnn/settings.json @@ -1 +1 @@ -{"run_args":{"tolerance":{"val":0.0,"scales":[1,1]},"input_scale":11,"param_scale":11,"scale_rebase_multiplier":1,"bits":25,"logrows":26,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_constraints":176820,"total_const_size":0,"model_instance_shapes":[[1,100]],"model_output_scales":[11],"module_sizes":{"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Div":{"denom":2048.0}},"ReLU"],"check_mode":"UNSAFE","version":"0.0.0"} \ No newline at end of file +{"run_args":{"tolerance":{"val":0.0,"scales":[1,1]},"input_scale":11,"param_scale":11,"scale_rebase_multiplier":1,"bits":25,"logrows":26,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_constraints":176820,"total_const_size":0,"model_instance_shapes":[[1,100]],"model_output_scales":[11],"model_input_scales":[11],"module_sizes":{"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Div":{"denom":2048.0}},"ReLU"],"check_mode":"UNSAFE","version":"0.0.0"} \ No newline at end of file diff --git a/src/eth.rs b/src/eth.rs index edf21ca3d..a65d2ed8e 100644 --- a/src/eth.rs +++ b/src/eth.rs @@ -126,8 +126,17 @@ pub async fn deploy_da_verifier_via_solidity( let mut contract_instance_offset = 0; if let DataSource::OnChain(source) = input.input_data { + let input_scales = settings.model_input_scales; for call in source.calls { calls_to_accounts.push(call); + } + + // give each input a scale + for scale in input_scales { + scales.extend(vec![ + scale; + instance_shapes[instance_idx].iter().product::() + ]); instance_idx += 1; } } else if let DataSource::File(source) = input.input_data { @@ -644,7 +653,7 @@ pub fn get_contract_artifacts( /// Sets the constants stored in the da verifier pub fn fix_da_sol( - input_data: Option<(u32, Vec)>, + input_data: Option>, output_data: Option>, ) -> Result> { @@ -653,14 +662,8 @@ pub fn fix_da_sol( // fill in the quantization params and total calls // as constants to the contract to save on gas if let Some(input_data) = input_data { - let input_calls: usize = input_data.1.iter().map(|v| v.call_data.len()).sum(); - let input_scale = input_data.0; - accounts_len = input_data.1.len(); - contract = contract.replace( - "uint public constant INPUT_SCALE = 1 << 0;", - &format!("uint public constant INPUT_SCALE = 1 << {};", input_scale), - ); - + let input_calls: usize = input_data.iter().map(|v| v.call_data.len()).sum(); + accounts_len = input_data.len(); contract = contract.replace( "uint256 constant INPUT_CALLS = 0;", &format!("uint256 constant INPUT_CALLS = {};", input_calls), diff --git a/src/execute.rs b/src/execute.rs index cb5a4a742..2e96ab62f 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -169,7 +169,7 @@ pub async fn run(cli: Cli) -> Result<(), Box> { sol_code_path, abi_path, data, - } => create_evm_data_attestation_verifier( + } => create_evm_data_attestation( vk_path, srs_path, settings_path, @@ -615,6 +615,7 @@ pub(crate) async fn calibrate( run_args: found_run_args, required_lookups: settings.required_lookups, model_output_scales: settings.model_output_scales, + model_input_scales: settings.model_input_scales, num_constraints: settings.num_constraints, total_const_size: settings.total_const_size, ..original_settings.clone() @@ -843,7 +844,7 @@ pub(crate) fn create_evm_verifier( } #[cfg(not(target_arch = "wasm32"))] -pub(crate) fn create_evm_data_attestation_verifier( +pub(crate) fn create_evm_data_attestation( vk_path: PathBuf, srs_path: PathBuf, settings_path: PathBuf, @@ -893,7 +894,7 @@ pub(crate) fn create_evm_data_attestation_verifier( for call in source.calls { on_chain_input_data.push(call); } - Some((settings.run_args.input_scale, on_chain_input_data)) + Some(on_chain_input_data) } else { None }; diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 78b443661..d40c03cd5 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -311,6 +311,8 @@ pub struct GraphSettings { pub model_instance_shapes: Vec>, /// model output scales pub model_output_scales: Vec, + /// model input scales + pub model_input_scales: Vec, /// the of instance cells used by modules pub module_sizes: ModuleSizes, /// required_lookups diff --git a/src/graph/model.rs b/src/graph/model.rs index a5690e483..952a0bda0 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -460,6 +460,7 @@ impl Model { num_constraints, required_lookups: lookup_ops, model_output_scales: self.graph.get_output_scales(), + model_input_scales: self.graph.get_input_scales(), total_const_size, check_mode, version: env!("CARGO_PKG_VERSION").to_string(), diff --git a/src/python.rs b/src/python.rs index f7241a3a2..38455d8c0 100644 --- a/src/python.rs +++ b/src/python.rs @@ -916,7 +916,7 @@ fn create_evm_data_attestation( abi_path: PathBuf, input_data: PathBuf, ) -> Result { - crate::execute::create_evm_data_attestation_verifier( + crate::execute::create_evm_data_attestation( vk_path, srs_path, settings_path, @@ -925,7 +925,7 @@ fn create_evm_data_attestation( input_data, ) .map_err(|e| { - let err_str = format!("Failed to run create_evm_data_attestation_verifier: {}", e); + let err_str = format!("Failed to run create_evm_data_attestation: {}", e); PyRuntimeError::new_err(err_str) })?; diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index b0776caf9..c4cb63428 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -1815,7 +1815,7 @@ mod native_tests { .status() .expect("failed to execute process"); assert!(status.success()); - + let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) .args([ "gen-witness", diff --git a/tests/wasm/settings.json b/tests/wasm/settings.json index aab22c42c..75f35954a 100644 --- a/tests/wasm/settings.json +++ b/tests/wasm/settings.json @@ -25,6 +25,9 @@ "model_output_scales": [ 7 ], + "model_input_scales": [ + 20 + ], "model_instance_shapes": [ [ 1,