diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index b003f522d..426540fcd 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -99,7 +99,6 @@ jobs: run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture --features icicle -- --include-ignored - name: Conv + relu overflow run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --features icicle -- --include-ignored - ultra-overflow-tests_og-lookup: runs-on: non-gpu @@ -134,7 +133,6 @@ jobs: - name: Conv + relu overflow run: cargo nextest run --release conv_relu_col_ultra_overflow --no-capture --no-default-features --features ezkl -- --include-ignored - ultra-overflow-tests: runs-on: non-gpu steps: @@ -629,6 +627,8 @@ jobs: run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_fixed_params_ - name: Public outputs run: source .env/bin/activate; cargo nextest run --release --verbose tests::accuracy_measurement_public_outputs_ + - name: Public outputs + resources + run: source .env/bin/activate; cargo nextest run --release --verbose tests::resources_accuracy_measurement_public_outputs_ python-integration-tests: runs-on: diff --git a/.gitignore b/.gitignore index ae3a37966..1a8583ec0 100644 --- a/.gitignore +++ b/.gitignore @@ -45,4 +45,6 @@ var/ *.whl *.bak node_modules -timingData.json \ No newline at end of file +timingData.json +!tests/wasm/pk.key +!tests/wasm/vk.key \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 1df9921b6..dc0466811 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2181,7 +2181,7 @@ dependencies = [ [[package]] name = "halo2_gadgets" version = "0.2.0" -source = "git+https://github.com/zkonduit/halo2?branch=ac/lookup-modularity#066c18e149e75ec97b5eb390a109a867cfb4e35e" +source = "git+https://github.com/zkonduit/halo2?branch=ac/lookup-modularity#aa437fbe92ae2e1d7d9993080cffb5338cbe3686" dependencies = [ "arrayvec 0.7.4", "bitvec 1.0.1", @@ -2198,7 +2198,7 @@ dependencies = [ [[package]] name = "halo2_proofs" version = "0.2.0" -source = "git+https://github.com/zkonduit/halo2?branch=ac/lookup-modularity#066c18e149e75ec97b5eb390a109a867cfb4e35e" +source = "git+https://github.com/zkonduit/halo2?branch=ac/lookup-modularity#aa437fbe92ae2e1d7d9993080cffb5338cbe3686" dependencies = [ "blake2b_simd", "env_logger", diff --git a/examples/notebooks/nbeats_timeseries_forecasting.ipynb b/examples/notebooks/nbeats_timeseries_forecasting.ipynb index 10e37e91e..5433cd279 100644 --- a/examples/notebooks/nbeats_timeseries_forecasting.ipynb +++ b/examples/notebooks/nbeats_timeseries_forecasting.ipynb @@ -834,12 +834,18 @@ }, "outputs": [], "source": [ + "run_args = ezkl.PyRunArgs()\n", + "run_args.input_visibility = \"private\"\n", + "run_args.param_visibility = \"fixed\"\n", + "run_args.output_visibility = \"public\"\n", + "run_args.variables = [(\"batch_size\", 1)]\n", + "\n", "!RUST_LOG=trace\n", "# TODO: Dictionary outputs\n", "res = ezkl.gen_settings(model_path, settings_path)\n", "assert res == True\n", "\n", - "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\")\n", + "res = ezkl.calibrate_settings(data_path, model_path, settings_path, \"resources\", max_logrows = 20, scales = [5,6])\n", "assert res == True" ] }, @@ -965,9 +971,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.15" + "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index bac708776..ca513a407 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -62,6 +62,7 @@ pub fn dot( let global_start = instant::Instant::now(); let mut values = values.clone(); + // this section has been optimized to death, don't mess with it let mut removal_indices = values[0].get_const_zero_indices()?; let second_zero_indices = values[1].get_const_zero_indices()?; @@ -104,6 +105,7 @@ pub fn dot( }; inputs.push(inp); } + let elapsed = start.elapsed(); trace!("assigning inputs took: {:?}", elapsed); diff --git a/src/circuit/ops/region.rs b/src/circuit/ops/region.rs index 3f203a3ee..4a1aa8d07 100644 --- a/src/circuit/ops/region.rs +++ b/src/circuit/ops/region.rs @@ -59,6 +59,11 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> { } impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { + /// + pub fn increment_total_constants(&mut self, n: usize) { + self.total_constants += n; + } + /// Create a new region context pub fn new(region: Region<'a, F>, row: usize, num_inner_cols: usize) -> RegionCtx<'a, F> { let region = Some(RefCell::new(region)); @@ -291,14 +296,16 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> { ) -> Result<(ValTensor, usize), Error> { if let Some(region) = &self.region { // duplicates every nth element to adjust for column overflow - var.assign_with_duplication( + let (res, len, total_assigned_constants) = var.assign_with_duplication( &mut region.borrow_mut(), self.row, self.linear_coord, values, check_mode, single_inner_col, - ) + )?; + self.total_constants += total_assigned_constants; + Ok((res, len)) } else { let (_, len, total_assigned_constants) = var.dummy_assign_with_duplication( self.row, diff --git a/src/execute.rs b/src/execute.rs index a46f1d9c5..63215d7db 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -623,8 +623,8 @@ pub(crate) fn calibrate( scales } else { match target { - CalibrationTarget::Resources { .. } => (4..8).collect::>(), - CalibrationTarget::Accuracy => (8..14).collect::>(), + CalibrationTarget::Resources { .. } => (8..10).collect::>(), + CalibrationTarget::Accuracy => (10..14).collect::>(), } }; diff --git a/src/graph/mod.rs b/src/graph/mod.rs index d84afb8f6..21cff33ba 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -458,6 +458,21 @@ impl GraphSettings { pub fn uses_modules(&self) -> bool { !self.module_sizes.max_constraints() > 0 } + + /// if any visibility is encrypted or hashed + pub fn module_requires_fixed(&self) -> bool { + if self.run_args.input_visibility.is_encrypted() + || self.run_args.input_visibility.is_hashed() + || self.run_args.output_visibility.is_encrypted() + || self.run_args.output_visibility.is_hashed() + || self.run_args.param_visibility.is_encrypted() + || self.run_args.param_visibility.is_hashed() + { + true + } else { + false + } + } } /// Configuration for a computational graph / model loaded from a `.onnx` file. @@ -1245,7 +1260,7 @@ impl Circuit for GraphCircuit { params.total_assignments, params.run_args.num_inner_cols, params.total_const_size, - params.uses_modules(), + params.module_requires_fixed(), ); module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone()); diff --git a/src/graph/model.rs b/src/graph/model.rs index 74fbfe795..9a2645ab9 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -1181,6 +1181,10 @@ impl Model { error!("{}", e); halo2_proofs::plonk::Error::Synthesis })?; + } else if !run_args.output_visibility.is_private() { + for output in &outputs { + thread_safe_region.increment_total_constants(output.num_constants()); + } } num_rows = thread_safe_region.row(); linear_coord = thread_safe_region.linear_coord(); @@ -1453,6 +1457,10 @@ impl Model { ) }) .collect::, _>>()?; + } else if !self.visibility.output.is_private() { + for output in &outputs { + region.increment_total_constants(output.num_constants()); + } } let duration = start_time.elapsed(); diff --git a/src/graph/vars.rs b/src/graph/vars.rs index 1a624abf7..b73a39a86 100644 --- a/src/graph/vars.rs +++ b/src/graph/vars.rs @@ -417,7 +417,7 @@ impl ModelVars { var_len: usize, num_inner_cols: usize, num_constants: usize, - uses_modules: bool, + module_requires_fixed: bool, ) -> Self { info!("number of blinding factors: {}", cs.blinding_factors()); @@ -431,7 +431,8 @@ impl ModelVars { num_inner_cols ); - let num_const_cols = VarTensor::constant_cols(cs, logrows, num_constants, uses_modules); + let num_const_cols = + VarTensor::constant_cols(cs, logrows, num_constants, module_requires_fixed); debug!("model uses {} fixed columns", num_const_cols); ModelVars { diff --git a/src/tensor/var.rs b/src/tensor/var.rs index 13cd08526..66cd5e30d 100644 --- a/src/tensor/var.rs +++ b/src/tensor/var.rs @@ -136,11 +136,11 @@ impl VarTensor { cs: &mut ConstraintSystem, logrows: usize, num_constants: usize, - uses_modules: bool, + module_requires_fixed: bool, ) -> usize { - if num_constants == 0 && !uses_modules { + if num_constants == 0 && !module_requires_fixed { return 0; - } else if num_constants == 0 && uses_modules { + } else if num_constants == 0 && module_requires_fixed { let col = cs.fixed_column(); cs.enable_constant(col); return 1; @@ -455,7 +455,7 @@ impl VarTensor { values: &ValTensor, check_mode: &CheckMode, single_inner_col: bool, - ) -> Result<(ValTensor, usize), halo2_proofs::plonk::Error> { + ) -> Result<(ValTensor, usize, usize), halo2_proofs::plonk::Error> { let mut prev_cell = None; match values { @@ -526,6 +526,7 @@ impl VarTensor { })?.into()}; let total_used_len = res.len(); + let total_constants = res.num_constants(); res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap(); res.reshape(dims).unwrap(); @@ -544,7 +545,7 @@ impl VarTensor { )}; } - Ok((res, total_used_len)) + Ok((res, total_used_len, total_constants)) } } } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 3155f5565..b63a98b5a 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -200,7 +200,7 @@ mod native_tests { "1l_prelu", ]; - const TESTS: [&str; 67] = [ + const TESTS: [&str; 66] = [ "1l_mlp", "1l_slice", "1l_concat", @@ -268,7 +268,7 @@ mod native_tests { "sklearn_mlp", "1l_mean", "rounding_ops", - "mean_as_constrain", + // "mean_as_constrain", "arange", "layernorm", ]; @@ -497,7 +497,7 @@ mod native_tests { } }); - seq!(N in 0..=66 { + seq!(N in 0..=65 { #(#[test_case(TESTS[N])])* #[ignore] @@ -515,7 +515,7 @@ mod native_tests { crate::native_tests::setup_py_env(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy"); + accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "accuracy", 1.2); test_dir.close().unwrap(); } @@ -525,7 +525,7 @@ mod native_tests { crate::native_tests::setup_py_env(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy"); + accuracy_measurement(path, test.to_string(), "private", "fixed", "private", 1, "accuracy", 1.2); test_dir.close().unwrap(); } @@ -535,7 +535,18 @@ mod native_tests { crate::native_tests::setup_py_env(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy"); + accuracy_measurement(path, test.to_string(), "public", "private", "private", 1, "accuracy", 1.2); + test_dir.close().unwrap(); + } + + + #(#[test_case(TESTS[N])])* + fn resources_accuracy_measurement_public_outputs_(test: &str) { + crate::native_tests::init_binary(); + crate::native_tests::setup_py_env(); + let test_dir = TempDir::new(test).unwrap(); + let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); + accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 18.0); test_dir.close().unwrap(); } @@ -1451,6 +1462,7 @@ mod native_tests { output_visibility: &str, batch_size: usize, cal_target: &str, + target_perc: f32, ) { gen_circuit_settings_and_witness( test_dir, @@ -1476,6 +1488,7 @@ mod native_tests { &format!("{}/{}/input.json", test_dir, example_name), &format!("{}/{}/witness.json", test_dir, example_name), &format!("{}/{}/settings.json", test_dir, example_name), + &format!("{}", target_perc), ]) .status() .expect("failed to execute process"); diff --git a/tests/output_comparison.py b/tests/output_comparison.py index 71c411f68..13ef96f31 100644 --- a/tests/output_comparison.py +++ b/tests/output_comparison.py @@ -96,6 +96,8 @@ def compare_outputs(zk_output, onnx_output): witness_file = sys.argv[3] # settings file is fourth argument to script settings_file = sys.argv[4] + # target + target = float(sys.argv[5]) # get the ezkl output ezkl_output = get_ezkl_output(witness_file, settings_file) # get the onnx output @@ -104,4 +106,4 @@ def compare_outputs(zk_output, onnx_output): percentage_difference = compare_outputs(ezkl_output, onnx_output) # print the percentage difference print("mean percent diff: ", percentage_difference) - assert percentage_difference < 1.2, "Percentage difference is too high" + assert percentage_difference < target, "Percentage difference is too high" diff --git a/tests/py_integration_tests.rs b/tests/py_integration_tests.rs index b1b11b6d9..93132b8d6 100644 --- a/tests/py_integration_tests.rs +++ b/tests/py_integration_tests.rs @@ -167,7 +167,7 @@ mod py_tests { use super::*; - seq!(N in 0..=34 { + seq!(N in 0..=36 { #(#[test_case(TESTS[N])])* fn run_notebook_(test: &str) { @@ -184,6 +184,7 @@ mod py_tests { test_dir.close().unwrap(); anvil_child.kill().unwrap(); } + }); #[test] fn voice_notebook_() { crate::py_tests::init_binary(); @@ -207,7 +208,7 @@ mod py_tests { run_notebook(path, "nbeats_timeseries_forecasting.ipynb"); test_dir.close().unwrap(); } - }); + } }; diff --git a/tests/wasm.rs b/tests/wasm.rs index 41360f309..ac065f350 100644 --- a/tests/wasm.rs +++ b/tests/wasm.rs @@ -27,14 +27,14 @@ mod wasm32 { wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); - pub const WITNESS: &[u8] = include_bytes!("../tests/wasm/test.witness.json"); - pub const NETWORK_COMPILED: &[u8] = include_bytes!("../tests/wasm/test_network.compiled"); + pub const WITNESS: &[u8] = include_bytes!("../tests/wasm/witness.json"); + pub const NETWORK_COMPILED: &[u8] = include_bytes!("../tests/wasm/network.compiled"); pub const NETWORK: &[u8] = include_bytes!("../tests/wasm/network.onnx"); pub const INPUT: &[u8] = include_bytes!("../tests/wasm/input.json"); pub const PROOF: &[u8] = include_bytes!("../tests/wasm/test.proof"); pub const SETTINGS: &[u8] = include_bytes!("../tests/wasm/settings.json"); - pub const PK: &[u8] = include_bytes!("../tests/wasm/test.provekey"); - pub const VK: &[u8] = include_bytes!("../tests/wasm/test.key"); + pub const PK: &[u8] = include_bytes!("../tests/wasm/pk.key"); + pub const VK: &[u8] = include_bytes!("../tests/wasm/vk.key"); pub const SRS: &[u8] = include_bytes!("../tests/wasm/kzg"); #[wasm_bindgen_test] @@ -375,6 +375,7 @@ mod wasm32 { wasm_bindgen::Clamped(PK.to_vec()), wasm_bindgen::Clamped(SETTINGS.to_vec()), ); + assert!(pk.is_ok()); // Run settings validation on proof (should fail) let settings = settingsValidation(wasm_bindgen::Clamped(PROOF.to_vec())); diff --git a/tests/wasm/network.compiled b/tests/wasm/network.compiled index 12a1b2f79..e27adfbfe 100644 Binary files a/tests/wasm/network.compiled and b/tests/wasm/network.compiled differ diff --git a/tests/wasm/pk.key b/tests/wasm/pk.key new file mode 100644 index 000000000..a8391993e Binary files /dev/null and b/tests/wasm/pk.key differ diff --git a/tests/wasm/settings.json b/tests/wasm/settings.json index 5320d18e6..848609e7c 100644 --- a/tests/wasm/settings.json +++ b/tests/wasm/settings.json @@ -1,67 +1 @@ -{ - "run_args": { - "tolerance": { - "val": 0.0, - "scale": 1.0 - }, - "num_inner_cols": 1, - "input_scale": 2, - "param_scale": 2, - "scale_rebase_multiplier": 1, - "lookup_range": [ - -32, - 12 - ], - "logrows": 6, - "variables": [ - [ - "batch_size", - 1 - ] - ], - "input_visibility": "Private", - "output_visibility": "Public", - "param_visibility": "Private" - }, - "total_assignments": 31, - "num_rows": 31, - "total_const_size": 0, - "model_instance_shapes": [ - [ - 1, - 4 - ] - ], - "model_output_scales": [ - 2 - ], - "model_input_scales": [ - 2 - ], - "module_sizes": { - "kzg": [], - "poseidon": [ - 0, - [ - 0 - ] - ], - "elgamal": [ - 0, - [ - 0 - ] - ] - }, - "required_lookups": [ - { - "Div": { - "denom": 4.0 - } - }, - "ReLU" - ], - "check_mode": "UNSAFE", - "version": "0.0.0", - "num_blinding_factors": null -} \ No newline at end of file +{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":2,"param_scale":2,"scale_rebase_multiplier":10,"lookup_range":[-44,30],"logrows":6,"num_inner_cols":1,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_rows":27,"total_assignments":27,"total_const_size":0,"model_instance_shapes":[[1,4]],"model_output_scales":[4],"model_input_scales":[2],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":["ReLU"],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null} \ No newline at end of file diff --git a/tests/wasm/test.key b/tests/wasm/test.key deleted file mode 100644 index 529a4c06c..000000000 Binary files a/tests/wasm/test.key and /dev/null differ diff --git a/tests/wasm/test.witness.json b/tests/wasm/test.witness.json deleted file mode 100644 index 398f2c2fe..000000000 --- a/tests/wasm/test.witness.json +++ /dev/null @@ -1 +0,0 @@ -{"inputs":[[[14385415396251402209,2429374486035521128,12558163205804149944,2583518171365219058],[6425625360762666998,7924344314350639699,14762033076929465436,2023505479389396574],[1949230679015292902,16913946402569752895,5177146667339417225,1571765431670520771]]],"outputs":[[[12436184717236109307,3962172157175319849,7381016538464732718,1011752739694698287],[7959790035488735211,12951774245394433045,16242874202584236123,560012691975822483],[0,0,0,0],[0,0,0,0]]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":6,"min_lookup_inputs":-16} \ No newline at end of file diff --git a/tests/wasm/test_network.compiled b/tests/wasm/test_network.compiled deleted file mode 100644 index 12a1b2f79..000000000 Binary files a/tests/wasm/test_network.compiled and /dev/null differ diff --git a/tests/wasm/vk.key b/tests/wasm/vk.key new file mode 100644 index 000000000..c1b185b4d Binary files /dev/null and b/tests/wasm/vk.key differ diff --git a/tests/wasm/witness.json b/tests/wasm/witness.json new file mode 100644 index 000000000..6ee61e2a8 --- /dev/null +++ b/tests/wasm/witness.json @@ -0,0 +1 @@ +{"inputs":[[[14385415396251402209,2429374486035521128,12558163205804149944,2583518171365219058],[6425625360762666998,7924344314350639699,14762033076929465436,2023505479389396574],[1949230679015292902,16913946402569752895,5177146667339417225,1571765431670520771]]],"outputs":[[[415066004289224689,11886516471525959549,3696305541684646538,3035258219084094862],[956231351009279921,10951436676983309100,2250248050743556928,1228298028208591648],[0,0,0,0],[0,0,0,0]]],"processed_inputs":null,"processed_params":null,"processed_outputs":null,"max_lookup_inputs":15,"min_lookup_inputs":-22} \ No newline at end of file