Skip to content

Commit

Permalink
feat: numerical accuracy reports post calibration (#647)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto authored Jan 6, 2024
1 parent e97713f commit 22689cf
Show file tree
Hide file tree
Showing 7 changed files with 338 additions and 46 deletions.
14 changes: 7 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
pyo3-log = { version = "0.8.1", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "ee98004a2d8d7851da7b9fce954b2a7a7181eccb", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "4ee813d", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }


Expand Down
152 changes: 147 additions & 5 deletions src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::pfsys::{
};
use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::RunArgs;
#[cfg(not(target_arch = "wasm32"))]
use ethers::types::H160;
Expand Down Expand Up @@ -64,6 +65,7 @@ use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use std::sync::OnceLock;
#[cfg(not(target_arch = "wasm32"))]
use std::time::Duration;
use tabled::Tabled;
use thiserror::Error;

#[cfg(not(target_arch = "wasm32"))]
Expand Down Expand Up @@ -625,6 +627,99 @@ pub(crate) fn init_bar(len: u64) -> ProgressBar {
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;

#[derive(Debug, Clone, Tabled)]
/// Accuracy tearsheet: summary statistics comparing the original (floating
/// point) model predictions against the calibrated circuit's predictions.
/// Signed errors are `original - calibrated`; percent errors are relative to
/// the original prediction.
pub struct AccuracyResults {
    /// Mean of the signed errors.
    mean_error: f32,
    /// Median of the signed errors.
    median_error: f32,
    /// Largest signed error.
    max_error: f32,
    /// Smallest signed error.
    min_error: f32,
    /// Mean of the absolute errors.
    mean_abs_error: f32,
    /// Median of the absolute errors.
    median_abs_error: f32,
    /// Largest absolute error.
    max_abs_error: f32,
    /// Smallest absolute error.
    min_abs_error: f32,
    /// Mean of the squared errors (MSE).
    mean_squared_error: f32,
    /// Mean of the relative errors (error / original prediction).
    mean_percent_error: f32,
    /// Mean of the absolute relative errors (MAPE).
    mean_abs_percent_error: f32,
}

impl AccuracyResults {
/// Create a new accuracy results struct
pub fn new(
mut original_preds: Vec<crate::tensor::Tensor<f32>>,
mut calibrated_preds: Vec<crate::tensor::Tensor<f32>>,
) -> Result<Self, Box<dyn Error>> {
let mut errors = vec![];
let mut abs_errors = vec![];
let mut squared_errors = vec![];
let mut percentage_errors = vec![];
let mut abs_percentage_errors = vec![];

for (original, calibrated) in original_preds.iter_mut().zip(calibrated_preds.iter_mut()) {
original.flatten();
calibrated.flatten();
let error = (original.clone() - calibrated.clone())?;
let abs_error = error.map(|x| x.abs());
let squared_error = error.map(|x| x.powi(2));
let percentage_error =
error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i].clone()))?;
let abs_percentage_error = percentage_error.map(|x| x.abs());

errors.extend(error.into_iter());
abs_errors.extend(abs_error.into_iter());
squared_errors.extend(squared_error.into_iter());
percentage_errors.extend(percentage_error.into_iter());
abs_percentage_errors.extend(abs_percentage_error.into_iter());
}

let mean_percent_error =
percentage_errors.iter().sum::<f32>() / percentage_errors.len() as f32;
let mean_abs_percent_error =
abs_percentage_errors.iter().sum::<f32>() / abs_percentage_errors.len() as f32;
let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
let median_error = errors[errors.len() / 2];
let max_error = errors
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
.clone();
let min_error = errors
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
.clone();

let mean_abs_error = abs_errors.iter().sum::<f32>() / abs_errors.len() as f32;
let median_abs_error = abs_errors[abs_errors.len() / 2];
let max_abs_error = abs_errors
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
.clone();
let min_abs_error = abs_errors
.iter()
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
.clone();

let mean_squared_error = squared_errors.iter().sum::<f32>() / squared_errors.len() as f32;

Ok(Self {
mean_error,
median_error,
max_error,
min_error,
mean_abs_error,
median_abs_error,
max_abs_error,
min_abs_error,
mean_squared_error,
mean_percent_error,
mean_abs_percent_error,
})
}
}

/// Calibrate the circuit parameters for a given dataset
#[cfg(not(target_arch = "wasm32"))]
#[allow(trivial_casts)]
Expand All @@ -637,6 +732,9 @@ pub(crate) fn calibrate(
scales: Option<Vec<crate::Scale>>,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
use std::collections::HashMap;
use tabled::Table;

let data = GraphData::from_path(data)?;
// load the pre-generated settings
let settings = GraphSettings::load(&settings_path)?;
Expand All @@ -655,6 +753,17 @@ pub(crate) fn calibrate(
#[cfg(unix)]
std::mem::drop(_r);

let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
info!("num of calibration batches: {}", chunks.len());

info!("running onnx predictions...");
let original_predictions = Model::run_onnx_predictions(
&settings.run_args,
&model_path,
&chunks,
model.graph.input_shapes()?,
)?;

let range = if let Some(scales) = scales {
scales
} else {
Expand All @@ -664,10 +773,6 @@ pub(crate) fn calibrate(
}
};

let chunks = data.split_into_batches(model.graph.input_shapes()?)?;

info!("num of calibration batches: {}", chunks.len());

let mut found_params: Vec<GraphSettings> = vec![];

let scale_rebase_multiplier = [1, 2, 10];
Expand Down Expand Up @@ -707,6 +812,8 @@ pub(crate) fn calibrate(
.map(|(a, b)| (*a, *b))
.collect::<Vec<((crate::Scale, crate::Scale), u32)>>();

let mut forward_pass_res = HashMap::new();

let pb = init_bar(range_grid.len() as u64);
pb.set_message("calibrating...");

Expand All @@ -729,6 +836,9 @@ pub(crate) fn calibrate(
Err(_) => None,
};

let key = (input_scale, param_scale, scale_rebase_multiplier);
forward_pass_res.insert(key, vec![]);

let tasks = chunks
.iter()
.zip(run_args_iterable)
Expand Down Expand Up @@ -757,10 +867,16 @@ pub(crate) fn calibrate(
.load_graph_from_file_exclusively(&chunk)
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;

circuit
let forward_res = circuit
.calibrate(&data, max_logrows, lookup_safety_margin)
.map_err(|e| format!("failed to calibrate: {}", e))?;

// push result to the hashmap
forward_pass_res
.get_mut(&key)
.ok_or("key not found")?
.push(forward_res);

let settings = circuit.settings().clone();

let found_run_args = RunArgs {
Expand Down Expand Up @@ -899,6 +1015,32 @@ pub(crate) fn calibrate(
}
};

let outputs = forward_pass_res
.get(&(
best_params.run_args.input_scale,
best_params.run_args.param_scale,
best_params.run_args.scale_rebase_multiplier,
))
.ok_or("no params found")?
.iter()
.map(|x| x.get_float_outputs(&best_params.model_output_scales))
.collect::<Vec<_>>();

let accuracy_res = AccuracyResults::new(
original_predictions.into_iter().flatten().collect(),
outputs.into_iter().flatten().collect(),
)?;

let tear_sheet_table = Table::new(vec![accuracy_res]);

println!(
"\n\n <------------- Numerical Fidelity Report (input_scale: {}, param_scale: {}, scale_input_multiplier: {}) ------------->\n\n{}\n\n",
best_params.run_args.input_scale,
best_params.run_args.param_scale,
best_params.run_args.scale_rebase_multiplier,
tear_sheet_table.to_string().as_str()
);

if matches!(target, CalibrationTarget::Resources { col_overflow: true }) {
let lookup_log_rows = ((best_params.run_args.lookup_range.1
- best_params.run_args.lookup_range.0) as f32)
Expand Down
41 changes: 38 additions & 3 deletions src/graph/input.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use super::quantize_float;
use super::GraphError;
use crate::circuit::InputType;
use crate::fieldutils::i128_to_felt;
#[cfg(not(target_arch = "wasm32"))]
Expand All @@ -18,11 +20,13 @@ use std::panic::UnwindSafe;
#[cfg(not(target_arch = "wasm32"))]
use std::thread;
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::{
tract_data::{prelude::Tensor as TractTensor, TVec},
value::TValue,
};
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;

use super::quantize_float;
use super::GraphError;

type Decimals = u8;
type Call = String;
type RPCUrl = String;
Expand Down Expand Up @@ -445,6 +449,37 @@ pub struct GraphData {
impl UnwindSafe for GraphData {}

impl GraphData {
// not wasm
#[cfg(not(target_arch = "wasm32"))]
/// Convert the input data to tract data
pub fn to_tract_data(
&self,
shapes: &[Vec<usize>],
datum_types: &[tract_onnx::prelude::DatumType],
) -> Result<TVec<TValue>, Box<dyn std::error::Error>> {
let mut inputs = TVec::new();
match &self.input_data {
DataSource::File(data) => {
for (i, input) in data.iter().enumerate() {
if !input.is_empty() {
let dt = datum_types[i];
let input = input.iter().map(|e| e.to_float()).collect::<Vec<f64>>();
let tt = TractTensor::from_shape(&shapes[i], &input)?;
let tt = tt.cast_to_dt(dt)?;
inputs.push(tt.into_owned().into());
}
}
}
_ => {
return Err(Box::new(GraphError::InvalidDims(
0,
"non file data cannot be split into batches".to_string(),
)))
}
}
Ok(inputs)
}

///
pub fn new(input_data: DataSource) -> Self {
GraphData {
Expand Down
19 changes: 17 additions & 2 deletions src/graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::table::{Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::RunArgs;
Expand Down Expand Up @@ -156,6 +157,19 @@ pub struct GraphWitness {
}

impl GraphWitness {
///
pub fn get_float_outputs(&self, scales: &[crate::Scale]) -> Vec<Tensor<f32>> {
self.outputs
.iter()
.enumerate()
.map(|(i, x)| {
x.iter()
.map(|y| (felt_to_f64(*y) / scale_to_multiplier(scales[i])) as f32)
.collect::<Tensor<f32>>()
})
.collect()
}

///
pub fn new(inputs: Vec<Vec<Fp>>, outputs: Vec<Vec<Fp>>) -> Self {
GraphWitness {
Expand Down Expand Up @@ -1109,9 +1123,10 @@ impl GraphCircuit {
input: &[Tensor<Fp>],
max_logrows: Option<u32>,
lookup_safety_margin: i128,
) -> Result<(), Box<dyn std::error::Error>> {
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
let res = self.forward(&mut input.to_vec(), None, None)?;
self.calc_min_logrows(&res, max_logrows, lookup_safety_margin)
self.calc_min_logrows(&res, max_logrows, lookup_safety_margin)?;
Ok(res)
}

/// Runs the forward pass of the model / graph of computations and any associated hashing.
Expand Down
Loading

0 comments on commit 22689cf

Please sign in to comment.