Skip to content

Commit

Permalink
Update execute.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Dec 5, 2023
1 parent 8ce620b commit 60e5e2a
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::pfsys::{
};
use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::RunArgs;
#[cfg(not(target_arch = "wasm32"))]
use ethers::types::H160;
Expand Down Expand Up @@ -605,6 +606,8 @@ pub struct AccuracyResults {
max_abs_error: f32,
min_abs_error: f32,
mean_squared_error: f32,
mean_percent_error: f32,
mean_abs_percent_error: f32,
}

impl AccuracyResults {
Expand All @@ -616,19 +619,30 @@ impl AccuracyResults {
let mut errors = vec![];
let mut abs_errors = vec![];
let mut squared_errors = vec![];
let mut percentage_errors = vec![];
let mut abs_percentage_errors = vec![];

for (original, calibrated) in original_preds.iter_mut().zip(calibrated_preds.iter_mut()) {
original.flatten();
calibrated.flatten();
let error = (original.clone() - calibrated.clone())?;
let abs_error = error.map(|x| x.abs());
let squared_error = error.map(|x| x.powi(2));
let percentage_error =
error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i].clone()))?;
let abs_percentage_error = percentage_error.map(|x| x.abs());

errors.extend(error.into_iter());
abs_errors.extend(abs_error.into_iter());
squared_errors.extend(squared_error.into_iter());
percentage_errors.extend(percentage_error.into_iter());
abs_percentage_errors.extend(abs_percentage_error.into_iter());
}

let mean_percent_error =
percentage_errors.iter().sum::<f32>() / percentage_errors.len() as f32;
let mean_abs_percent_error =
abs_percentage_errors.iter().sum::<f32>() / abs_percentage_errors.len() as f32;
let mean_error = errors.iter().sum::<f32>() / errors.len() as f32;
let median_error = errors[errors.len() / 2];
let max_error = errors
Expand Down Expand Up @@ -667,6 +681,8 @@ impl AccuracyResults {
max_abs_error,
min_abs_error,
mean_squared_error,
mean_percent_error,
mean_abs_percent_error,
})
}
}
Expand Down

0 comments on commit 60e5e2a

Please sign in to comment.