chore: broader onnx type to felt conversion support (#486)
alexander-camuto authored Sep 15, 2023
1 parent c57be0b commit 5df3703
Showing 11 changed files with 136 additions and 88 deletions.
36 changes: 28 additions & 8 deletions examples/notebooks/mean_postgres.ipynb
@@ -229,6 +229,22 @@
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
"\n",
"await ezkl.calibrate_settings(\n",
" input_filename, onnx_filename, settings_filename, \"resources\")"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -239,7 +255,18 @@
"source": [
"# setup kzg params\n",
"params_path = os.path.join('kzg.params')\n",
"res = ezkl.gen_srs(params_path, 7)"
"\n",
"res = ezkl.get_srs(params_path, settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"ezkl.compile_model(onnx_filename, compiled_filename, settings_filename)"
]
},
{
@@ -255,14 +282,7 @@
"outputs": [],
"source": [
"# generate settings\n",
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
"ezkl.calibrate_settings(\n",
" input_filename, onnx_filename, settings_filename, \"resources\")\n",
"ezkl.compile_model(onnx_filename, compiled_filename, settings_filename)\n",
"\n",
"# show the settings.json\n",
"with open(\"settings.json\") as f:\n",
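Note: the notebook cells above were reordered so that settings are generated and calibrated before the SRS is fetched, and get_srs(params_path, settings_filename) replaces the old gen_srs(params_path, 7), sizing the SRS from the calibrated settings instead of a hardcoded logrows. A minimal consolidation sketch of the new flow, using only calls that appear in this diff (input_filename is defined in an earlier notebook cell, and lol.onnx is assumed to already exist on disk):

import os
import ezkl

onnx_filename = os.path.join('lol.onnx')
compiled_filename = os.path.join('lol.compiled')
settings_filename = os.path.join('settings.json')
params_path = os.path.join('kzg.params')

async def setup(input_filename):
    # 1. derive circuit settings from the model
    ezkl.gen_settings(onnx_filename, settings_filename)
    # 2. calibrate them against sample input (async in this version of ezkl)
    await ezkl.calibrate_settings(
        input_filename, onnx_filename, settings_filename, "resources")
    # 3. fetch an SRS sized to the calibrated settings
    ezkl.get_srs(params_path, settings_filename)
    # 4. compile the model against the final settings
    ezkl.compile_model(onnx_filename, compiled_filename, settings_filename)

# in a notebook cell: await setup(input_filename)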
10 changes: 4 additions & 6 deletions examples/notebooks/random_forest.ipynb
@@ -81,8 +81,6 @@
"for tree in clr.estimators_:\n",
" trees.append(sk2torch.wrap(tree))\n",
"\n",
"print(trees)\n",
"\n",
"\n",
"class RandomForest(nn.Module):\n",
" def __init__(self, trees):\n",
@@ -98,13 +96,13 @@
"\n",
"torch_rf = RandomForest(trees)\n",
"# assert predictions from torch are = to sklearn \n",
"\n",
"diffs = []\n",
"for i in range(len(X_test)):\n",
" torch_pred = torch_rf(torch.tensor(X_test[i].reshape(1, -1)))\n",
" sk_pred = clr.predict(X_test[i].reshape(1, -1))\n",
" print(torch_pred, sk_pred[0])\n",
" assert torch_pred[0].round() == sk_pred[0]\n",
"\n"
" diffs.append(torch_pred[0].round() - sk_pred[0])\n",
"\n",
"print(\"num diffs\", sum(diffs))\n"
]
},
{
13 changes: 11 additions & 2 deletions examples/notebooks/variance.ipynb
@@ -248,7 +248,7 @@
"\n",
"\n",
"ezkl.gen_settings(onnx_filename, settings_filename)\n",
"ezkl.calibrate_settings(\n",
"await ezkl.calibrate_settings(\n",
" input_filename, onnx_filename, settings_filename, \"resources\")\n",
"ezkl.compile_model(onnx_filename, compiled_filename, settings_filename)\n",
"\n",
@@ -462,7 +462,16 @@
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
1 change: 1 addition & 0 deletions examples/onnx/linear_svc/input.json

Large diffs are not rendered by default.

Binary file added examples/onnx/linear_svc/network.onnx
8 changes: 0 additions & 8 deletions src/circuit/ops/hybrid.rs
@@ -329,14 +329,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
}))
}

fn requires_specific_input_scales(&self) -> Vec<(usize, u32)> {
match self {
HybridOp::Gather { .. } | HybridOp::GatherElements { .. } => vec![(1, 0)],
HybridOp::OneHot { .. } => vec![(0, 0)],
_ => vec![],
}
}

fn out_scale(&self, in_scales: Vec<u32>) -> u32 {
match self {
HybridOp::Greater { .. }
5 changes: 0 additions & 5 deletions src/circuit/ops/mod.rs
@@ -52,11 +52,6 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
vec![]
}

/// Do any of the inputs to this op require specific input scales?
fn requires_specific_input_scales(&self) -> Vec<(usize, u32)> {
vec![]
}

/// Returns the lookups required by the operation.
fn required_lookups(&self) -> Vec<LookupOp> {
vec![]
40 changes: 0 additions & 40 deletions src/graph/node.rs
@@ -391,19 +391,6 @@ impl Op<Fp> for SupportedOp {
}
}

fn requires_specific_input_scales(&self) -> Vec<(usize, u32)> {
match self {
SupportedOp::Linear(op) => Op::<Fp>::requires_specific_input_scales(op),
SupportedOp::Nonlinear(op) => Op::<Fp>::requires_specific_input_scales(op),
SupportedOp::Hybrid(op) => Op::<Fp>::requires_specific_input_scales(op),
SupportedOp::Input(op) => Op::<Fp>::requires_specific_input_scales(op),
SupportedOp::Constant(op) => Op::<Fp>::requires_specific_input_scales(op),
SupportedOp::Unknown(op) => Op::<Fp>::requires_specific_input_scales(op),
SupportedOp::Rescaled(op) => Op::<Fp>::requires_specific_input_scales(op),
SupportedOp::RebaseScale(op) => Op::<Fp>::requires_specific_input_scales(op),
}
}

fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
match self {
SupportedOp::Linear(op) => Box::new(op.clone()),
@@ -608,33 +595,6 @@ impl Node {
}
}

let inputs_at_specific_scales = opkind.requires_specific_input_scales();
// rescale the inputs if necessary to get consistent fixed points
for (input, scale) in inputs_at_specific_scales
.into_iter()
.filter(|(i, _)| !deleted_indices.contains(i))
{
let input_node = other_nodes.get_mut(&inputs[input].idx()).unwrap();
let input_opkind = &mut input_node.opkind();
if let Some(constant) = input_opkind.get_mutable_constant() {
rescale_const_with_single_use(constant, in_scales.clone(), param_visibility)?;
input_node.replace_opkind(constant.clone_dyn().into());
let out_scale = input_opkind.out_scale(vec![]);
input_node.bump_scale(out_scale);
in_scales[input] = out_scale;
} else {
let scale_diff = in_scales[input] as i128 - scale as i128;
let rebased = if scale_diff > 0 {
RebaseScale::rebase(input_opkind.clone(), scale, in_scales[input], 1)
} else {
RebaseScale::rebase_up(input_opkind.clone(), scale, in_scales[input])
};
input_node.replace_opkind(rebased);
input_node.bump_scale(scale);
in_scales[input] = scale;
}
}

opkind = opkind.homogenous_rescale(in_scales.clone()).into();
let mut out_scale = opkind.out_scale(in_scales.clone());
opkind = RebaseScale::rebase(opkind, scales.input, out_scale, scales.rebase_multiplier);
77 changes: 68 additions & 9 deletions src/graph/utilities.rs
@@ -106,6 +106,11 @@ fn extract_tensor_value(
}

match dt {
DatumType::F16 => {
let vec = input.as_slice::<tract_onnx::prelude::f16>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| (*x).into()).collect();
const_value = cast.into_iter().into();
}
DatumType::F32 => {
let vec = input.as_slice::<f32>()?.to_vec();
const_value = vec.into_iter().into();
@@ -139,6 +144,30 @@
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U8 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u8>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U16 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u16>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U32 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u32>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::U64 => {
// Generally a shape or hyperparam
let vec = input.as_slice::<u64>()?.to_vec();
let cast: Vec<f32> = vec.iter().map(|x| *x as f32).collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
DatumType::Bool => {
// Generally a shape or hyperparam
let vec = input.as_slice::<bool>()?.to_vec();
@@ -154,7 +183,7 @@
.collect();
const_value = Tensor::<f32>::new(Some(&cast), &dims)?;
}
_ => todo!(),
_ => todo!("unsupported type"),
}
const_value.reshape(&dims);

@@ -332,10 +361,19 @@ pub fn new_op_from_onnx(
// Raw values are always f32
let raw_value = extract_tensor_value(op.0)?;
// If bool or a tensor dimension then don't scale
let constant_scale = if dt == DatumType::Bool || dt == DatumType::TDim {
0
} else {
scales.params
let constant_scale = match dt {
DatumType::Bool
| DatumType::TDim
| DatumType::I64
| DatumType::I32
| DatumType::I16
| DatumType::I8
| DatumType::U8
| DatumType::U16
| DatumType::U32
| DatumType::U64 => 0,
DatumType::F16 | DatumType::F32 | DatumType::F64 => scales.params,
_ => todo!("unsupported type"),
};
// Quantize the raw value
let quantized_value =
@@ -588,7 +626,16 @@ pub fn new_op_from_onnx(
let (scale, datum_type) = match node.outputs[0].fact.datum_type {
DatumType::Bool => (0, InputType::Bool),
DatumType::TDim => (0, InputType::TDim),
_ => (scales.input, InputType::Num),
DatumType::I64
| DatumType::I32
| DatumType::I16
| DatumType::I8
| DatumType::U8
| DatumType::U16
| DatumType::U32
| DatumType::U64 => (0, InputType::Num),
DatumType::F16 | DatumType::F32 | DatumType::F64 => (scales.input, InputType::Num),
_ => todo!(),
};
SupportedOp::Input(crate::circuit::ops::Input { scale, datum_type })
}
@@ -600,12 +647,24 @@
.flat_map(|x| x.out_scales())
.collect::<Vec<_>>();
assert_eq!(input_scales.len(), 1);

match dt {
DatumType::Bool => SupportedOp::Nonlinear(LookupOp::Div {
DatumType::Bool
| DatumType::TDim
| DatumType::I64
| DatumType::I32
| DatumType::I16
| DatumType::I8
| DatumType::U8
| DatumType::U16
| DatumType::U32
| DatumType::U64 => SupportedOp::Nonlinear(LookupOp::Div {
denom: crate::circuit::utils::F32(scale_to_multiplier(input_scales[0]) as f32),
}),
DatumType::String | DatumType::Blob => unimplemented!(),
_ => SupportedOp::Linear(PolyOp::Identity),
DatumType::F16 | DatumType::F32 | DatumType::F64 => {
SupportedOp::Linear(PolyOp::Identity)
}
_ => todo!("unsupported type"),
}
}
"Add" => SupportedOp::Linear(PolyOp::Add),
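Note: the new arms pin integer-like datum types (Bool, TDim, and I8 through U64) to scale 0, while F16/F32/F64 keep the configured params or input scale, and the Cast path divides by scale_to_multiplier(input_scale) to rebase a float input back to scale 0. A self-contained sketch of the fixed-point convention this relies on, assuming (as elsewhere in ezkl) that scale_to_multiplier(s) = 2^s:

// Sketch of the fixed-point convention behind the scale logic above.
// Assumption: scale_to_multiplier(s) is 2^s, so a value x at scale s
// is represented in the field as round(x * 2^s).
fn scale_to_multiplier(scale: u32) -> f64 {
    f64::powi(2.0, scale as i32)
}

/// Quantize an f32 into a felt-ready integer at the given scale.
fn quantize(x: f32, scale: u32) -> i128 {
    (x as f64 * scale_to_multiplier(scale)).round() as i128
}

fn main() {
    // Integer-typed constants get scale 0: they map to themselves, so no
    // params-scale multiplier is needed.
    assert_eq!(quantize(42.0, 0), 42);
    // Float params at the default scale of 7 are scaled by 2^7 = 128.
    assert_eq!(quantize(0.5, 7), 64);
    // Casting back down divides by the same multiplier, mirroring the
    // LookupOp::Div { denom } arm above.
    assert_eq!(quantize(0.5, 7) as f64 / scale_to_multiplier(7), 0.5);
}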
28 changes: 21 additions & 7 deletions src/python.rs
@@ -258,8 +258,6 @@ struct PyRunArgs {
pub param_visibility: Visibility,
#[pyo3(get, set)]
pub variables: Vec<(String, usize)>,
#[pyo3(get, set)]
pub allocated_constraints: Option<usize>,
}

/// default instantiation of PyRunArgs
@@ -271,14 +269,13 @@ impl PyRunArgs {
tolerance: 0.0,
input_scale: 7,
param_scale: 7,
scale_rebase_multiplier: 2,
scale_rebase_multiplier: 1,
bits: 16,
logrows: 17,
input_visibility: "public".into(),
output_visibility: "public".into(),
param_visibility: "private".into(),
input_visibility: Visibility::Private,
output_visibility: Visibility::Public,
param_visibility: Visibility::Private,
variables: vec![("batch_size".to_string(), 1)],
allocated_constraints: None,
}
}
}
Expand All @@ -301,6 +298,23 @@ impl From<PyRunArgs> for RunArgs {
}
}

impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
PyRunArgs {
tolerance: self.tolerance.val.into(),
input_scale: self.input_scale,
param_scale: self.param_scale,
scale_rebase_multiplier: self.scale_rebase_multiplier,
bits: self.bits,
logrows: self.logrows,
input_visibility: self.input_visibility,
output_visibility: self.output_visibility,
param_visibility: self.param_visibility,
variables: self.variables,
}
}
}

/// Converts 4 u64s to a field element
#[pyfunction(signature = (
array,
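Note: the reverse conversion is written as impl Into<PyRunArgs> for RunArgs; the more idiomatic direction is to implement From, which yields Into for free through the standard library's blanket impl. A stubbed, self-contained sketch (fields abbreviated; these are not the real structs):

// Stub structs standing in for RunArgs/PyRunArgs to keep the sketch runnable.
#[derive(Debug)]
struct RunArgs { input_scale: u32, logrows: u32 }
#[derive(Debug)]
struct PyRunArgs { input_scale: u32, logrows: u32 }

impl From<RunArgs> for PyRunArgs {
    fn from(a: RunArgs) -> Self {
        PyRunArgs { input_scale: a.input_scale, logrows: a.logrows }
    }
}

fn main() {
    let args = RunArgs { input_scale: 7, logrows: 17 };
    // `.into()` now works via the blanket `impl<T, U: From<T>> Into<U> for T`.
    let py: PyRunArgs = args.into();
    println!("{py:?}");
}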
6 changes: 3 additions & 3 deletions tests/integration_tests.rs
@@ -169,7 +169,7 @@ mod native_tests {
"1l_prelu",
];

const TESTS: [&str; 53] = [
const TESTS: [&str; 54] = [
"1l_mlp",
"1l_slice",
"1l_concat",
Expand Down Expand Up @@ -226,6 +226,7 @@ mod native_tests {
"lightgbm",
"hummingbird_decision_tree",
"oh_decision_tree",
"linear_svc",
];

const TESTS_AGGR: [&str; 21] = [
@@ -395,8 +396,7 @@

seq!(N in 0..=52 {

seq!(N in 0..=53 {
#(#[test_case(TESTS[N])])*
fn model_serialization_(test: &str) {
let test_dir = TempDir::new(test).unwrap();

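Note: the seq! upper bound has to track TESTS.len() - 1 (here 53 for the 54 entries after adding "linear_svc"); if the range is not bumped alongside the array, the trailing tests are silently never generated. A reduced sketch of the pattern, assuming the seq-macro and test_case crates as used in this test module:

use seq_macro::seq;
use test_case::test_case;

const TESTS: [&str; 3] = ["1l_mlp", "1l_slice", "linear_svc"];

// seq! stamps out one #[test_case] per N, so the inclusive range end must
// equal TESTS.len() - 1 or trailing entries are dropped from generation.
seq!(N in 0..=2 {
    #(#[test_case(TESTS[N])])*
    fn model_serialization_(test: &str) {
        assert!(!test.is_empty());
    }
});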