diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 68706eab7..7e5301c40 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -236,6 +236,8 @@ jobs: with: crate: cargo-nextest locked: true + - name: public outputs and tolerance > 0 + run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32 - name: public outputs + batch size == 10 run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32 - name: kzg inputs diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 55014fd8f..f37691542 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -2953,8 +2953,17 @@ pub fn range_check_percent( return enforce_equality(config, region, values); } + let mut values = [values[0].clone(), values[1].clone()]; + + values[0] = region.assign(&config.inputs[0], &values[0])?; + values[1] = region.assign(&config.inputs[1], &values[1])?; + let total_assigned_0 = values[0].len(); + let total_assigned_1 = values[1].len(); + let total_assigned = std::cmp::max(total_assigned_0, total_assigned_1); + region.increment(total_assigned); + // Calculate the difference between the expected output and actual output - let diff = pairwise(config, region, values, BaseOp::Sub)?; + let diff = pairwise(config, region, &values, BaseOp::Sub)?; // Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor let recip = nonlinearity( diff --git a/src/circuit/ops/lookup.rs b/src/circuit/ops/lookup.rs index 0e891cf7f..3ba39da8b 100644 --- a/src/circuit/ops/lookup.rs +++ b/src/circuit/ops/lookup.rs @@ -243,10 +243,10 @@ impl Op for LookupOp { LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a), LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a), LookupOp::Sign => "SIGN".into(), - LookupOp::GreaterThan { .. } => "GREATER_THAN".into(), - LookupOp::GreaterThanEqual { .. } => "GREATER_THAN_EQUAL".into(), - LookupOp::LessThan { .. } => "LESS_THAN".into(), - LookupOp::LessThanEqual { .. } => "LESS_THAN_EQUAL".into(), + LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a), + LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a), + LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a), + LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a), LookupOp::Recip { input_scale, output_scale, diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 48b22bede..842865b1e 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -1248,7 +1248,7 @@ impl GraphCircuit { } } - let mut model_results = self.model().forward(inputs)?; + let mut model_results = self.model().forward(inputs, &self.settings().run_args)?; if visibility.output.requires_processing() { let module_outlets = visibility.output.overwrites_inputs(); diff --git a/src/graph/model.rs b/src/graph/model.rs index 964717a89..46e707230 100644 --- a/src/graph/model.rs +++ b/src/graph/model.rs @@ -556,12 +556,16 @@ impl Model { /// * `reader` - A reader for an Onnx file. /// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model. /// * `run_args` - [RunArgs] - pub fn forward(&self, model_inputs: &[Tensor]) -> Result> { + pub fn forward( + &self, + model_inputs: &[Tensor], + run_args: &RunArgs, + ) -> Result> { let valtensor_inputs: Vec> = model_inputs .iter() .map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into()) .collect(); - let res = self.dummy_layout(&RunArgs::default(), &valtensor_inputs)?; + let res = self.dummy_layout(&run_args, &valtensor_inputs)?; Ok(res.into()) } @@ -1371,27 +1375,26 @@ impl Model { ValType::Constant(Fp::ONE) }; - let comparator = outputs + let output_scales = self.graph.get_output_scales()?; + let res = outputs .iter() - .map(|x| { - let mut v: ValTensor = - vec![default_value.clone(); x.dims().iter().product::()].into(); - v.reshape(x.dims())?; - Ok(v) - }) - .collect::, Box>>()?; + .enumerate() + .map(|(i, output)| { + let mut tolerance = run_args.tolerance; + tolerance.scale = scale_to_multiplier(output_scales[i]).into(); + + let mut comparator: ValTensor = + vec![default_value.clone(); output.dims().iter().product::()].into(); + comparator.reshape(output.dims())?; - let _ = outputs - .iter() - .zip(comparator) - .map(|(o, c)| { dummy_config.layout( &mut region, - &[o.clone(), c], - Box::new(HybridOp::RangeCheck(run_args.tolerance)), + &[output.clone(), comparator], + Box::new(HybridOp::RangeCheck(tolerance)), ) }) - .collect::, _>>()?; + .collect::, _>>(); + res?; } else if !self.visibility.output.is_private() { for output in &outputs { region.increment_total_constants(output.num_constants()); diff --git a/src/pfsys/mod.rs b/src/pfsys/mod.rs index f62142bbb..7f3152129 100644 --- a/src/pfsys/mod.rs +++ b/src/pfsys/mod.rs @@ -728,12 +728,13 @@ where let f = File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?; let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f); - VerifyingKey::::read::<_, C>( + let vk = VerifyingKey::::read::<_, C>( &mut reader, serde_format_from_str(&EZKL_KEY_FORMAT), params, - ) - .map_err(Box::::from) + )?; + info!("done loading verification key ✅"); + Ok(vk) } /// Loads a [ProvingKey] at `path`. @@ -750,12 +751,13 @@ where let f = File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?; let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f); - ProvingKey::::read::<_, C>( + let pk = ProvingKey::::read::<_, C>( &mut reader, serde_format_from_str(&EZKL_KEY_FORMAT), params, - ) - .map_err(Box::::from) + )?; + info!("done loading proving key ✅"); + Ok(pk) } /// Saves a [ProvingKey] to `path`. @@ -772,6 +774,7 @@ where let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f); pk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?; writer.flush()?; + info!("done saving proving key ✅"); Ok(()) } @@ -789,6 +792,7 @@ where let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f); vk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?; writer.flush()?; + info!("done saving verification key ✅"); Ok(()) } diff --git a/src/tensor/val.rs b/src/tensor/val.rs index 890cb75c7..6aa1595c0 100644 --- a/src/tensor/val.rs +++ b/src/tensor/val.rs @@ -672,7 +672,7 @@ impl ValTensor { } Ok(indices) } - ValTensor::Instance { .. } => Err(TensorError::WrongMethod), + ValTensor::Instance { .. } => Ok(vec![]), } } @@ -690,7 +690,7 @@ impl ValTensor { } Ok(indices) } - ValTensor::Instance { .. } => Err(TensorError::WrongMethod), + ValTensor::Instance { .. } => Ok(vec![]), } } @@ -709,7 +709,11 @@ impl ValTensor { *d = v.dims().to_vec(); } ValTensor::Instance { .. } => { - return Err(TensorError::WrongMethod); + if indices.is_empty() { + return Ok(()); + } else { + return Err(TensorError::WrongMethod); + } } } Ok(()) diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 3ac43301b..ba4a098d6 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -477,6 +477,7 @@ mod native_tests { use crate::native_tests::kzg_fuzz; use crate::native_tests::render_circuit; use crate::native_tests::model_serialization_different_binaries; + use rand::Rng; use tempdir::TempDir; #[test] @@ -496,7 +497,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None); + mock(path, test.to_string(), "public", "fixed", "public", 1, "accuracy", None, 0.0); test_dir.close().unwrap(); } }); @@ -569,7 +570,18 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "private", "private", "public", 1, "resources", None); + mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, 0.0); + test_dir.close().unwrap(); + } + + #(#[test_case(TESTS[N])])* + fn mock_tolerance_public_outputs_(test: &str) { + crate::native_tests::init_binary(); + let test_dir = TempDir::new(test).unwrap(); + let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); + // gen random number between 0.0 and 1.0 + let tolerance = rand::thread_rng().gen_range(0.0..1.0) * 100.0; + mock(path, test.to_string(), "private", "private", "public", 1, "resources", None, tolerance); test_dir.close().unwrap(); } @@ -580,7 +592,7 @@ mod native_tests { let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); let large_batch_dir = &format!("large_batches_{}", test); crate::native_tests::mk_data_batches_(path, test, &large_batch_dir, 10); - mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None); + mock(path, large_batch_dir.to_string(), "private", "private", "public", 10, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -589,7 +601,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "public", "private", "private", 1, "resources", None); + mock(path, test.to_string(), "public", "private", "private", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -598,7 +610,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None); + mock(path, test.to_string(), "fixed", "private", "private", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -607,7 +619,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None); + mock(path, test.to_string(), "private", "private", "fixed", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -616,7 +628,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None); + mock(path, test.to_string(), "private", "fixed", "private", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -625,7 +637,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None); + mock(path, test.to_string(), "hashed", "private", "public", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -634,7 +646,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "kzgcommit", "private", "public", 1, "resources", None); + mock(path, test.to_string(), "kzgcommit", "private", "public", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -644,7 +656,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None); + mock(path, test.to_string(), "private", "hashed", "public", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -654,7 +666,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "private", "kzgcommit", "public", 1, "resources", None); + mock(path, test.to_string(), "private", "kzgcommit", "public", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -663,7 +675,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None); + mock(path, test.to_string(), "public", "private", "hashed", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -673,7 +685,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "public", "private", "kzgcommit", 1, "resources", None); + mock(path, test.to_string(), "public", "private", "kzgcommit", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -682,7 +694,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None); + mock(path, test.to_string(), "public", "fixed", "hashed", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -692,7 +704,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "public", "kzgcommit", "hashed", 1, "resources", None); + mock(path, test.to_string(), "public", "kzgcommit", "hashed", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -702,7 +714,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit", 1, "resources", None); + mock(path, test.to_string(), "kzgcommit", "kzgcommit", "kzgcommit", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -712,7 +724,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None); + mock(path, test.to_string(), "hashed", "private", "hashed", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -722,7 +734,7 @@ mod native_tests { let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); // needs an extra row for the large model - mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None); + mock(path, test.to_string(),"hashed", "hashed", "public", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -732,7 +744,7 @@ mod native_tests { let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); // needs an extra row for the large model - mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None); + mock(path, test.to_string(),"hashed", "hashed", "hashed", 1, "resources", None, 0.0); test_dir.close().unwrap(); } @@ -876,7 +888,7 @@ mod native_tests { crate::native_tests::init_binary(); let test_dir = TempDir::new(test).unwrap(); let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test); - mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None); + mock(path, test.to_string(), "private", "fixed", "public", 1, "resources", None, 0.0); test_dir.close().unwrap(); } }); @@ -1273,6 +1285,7 @@ mod native_tests { batch_size: usize, cal_target: &str, scales_to_use: Option>, + tolerance: f32, ) { gen_circuit_settings_and_witness( test_dir, @@ -1285,6 +1298,7 @@ mod native_tests { scales_to_use, 2, false, + tolerance, ); let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) @@ -1312,6 +1326,7 @@ mod native_tests { scales_to_use: Option>, num_inner_columns: usize, div_rebasing: bool, + tolerance: f32, ) { let mut args = vec![ "gen-settings".to_string(), @@ -1326,6 +1341,7 @@ mod native_tests { format!("--param-visibility={}", param_visibility), format!("--output-visibility={}", output_visibility), format!("--num-inner-cols={}", num_inner_columns), + format!("--tolerance={}", tolerance), ]; if div_rebasing { @@ -1425,6 +1441,7 @@ mod native_tests { None, 2, div_rebasing, + 0.0, ); println!( @@ -1684,6 +1701,7 @@ mod native_tests { scales_to_use, num_inner_columns, false, + 0.0, ); let settings_path = format!("{}/{}/settings.json", test_dir, example_name); @@ -1785,6 +1803,7 @@ mod native_tests { None, 2, false, + 0.0, ); let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR)) @@ -2061,6 +2080,7 @@ mod native_tests { Some(vec![4]), 1, false, + 0.0, ); let model_path = format!("{}/{}/network.compiled", test_dir, example_name);