feat: onehot op
alexander-camuto committed Sep 10, 2023
1 parent 8ad1e82 commit 741ffaf
Showing 10 changed files with 255 additions and 11 deletions.
14 changes: 7 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -51,7 +51,7 @@ tokio = { version = "1.26.0", default_features = false, features = ["macros", "
pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
pyo3-log = { version = "0.8.1", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "561614519e6cb49eea4d88dcee3b880f127813cb", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "ca451330931687f00bccd30a1e2b5ec2fcccdcc9", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }


50 changes: 50 additions & 0 deletions examples/onnx/hummingbird_decision_tree/gen.py
@@ -0,0 +1,50 @@
# Train a model.
import json

import numpy as np
import torch
from hummingbird.ml import convert
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target
X = X.astype(np.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = DecisionTreeClassifier()
clr.fit(X_train, y_train)

torch_model = convert(clr, "pytorch").model


# Export the trained model to ONNX format.

# Input to the model
shape = X_train.shape[1:]
x = torch.rand(1, *shape, requires_grad=True)
torch_out = torch_model(x)
# Export the model
torch.onnx.export(torch_model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "network.onnx",            # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                'output': {0: 'batch_size'}})

d = x.detach().numpy().reshape([-1]).tolist()

data = dict(input_shapes=[shape],
            input_data=[d],
            output_data=[o.detach().numpy().reshape([-1]).tolist() for o in torch_out])

# Serialize data into a file:
with open("input.json", 'w') as f:
    json.dump(data, f)
1 change: 1 addition & 0 deletions examples/onnx/hummingbird_decision_tree/input.json
@@ -0,0 +1 @@
{"input_shapes": [[4]], "input_data": [[0.9813985824584961, 0.793540358543396, 0.548916757106781, 0.6483156681060791]], "output_data": [[0], [1.0, 0.0, 0.0]]}
Binary file not shown.
14 changes: 14 additions & 0 deletions src/circuit/ops/hybrid.rs
@@ -49,6 +49,10 @@ pub enum HybridOp {
dim: usize,
k: usize,
},
OneHot {
dim: usize,
num_classes: usize,
},
GatherElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
@@ -129,6 +133,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
(res.clone(), inter_equals)
}
}
HybridOp::OneHot { dim, num_classes } => {
let res = tensor::ops::one_hot(&x, *num_classes, *dim)?;
(res.clone(), vec![])
}
HybridOp::TopK { dim, k } => {
let res = tensor::ops::topk_axes(&x, *k, *dim)?;

@@ -233,6 +241,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
HybridOp::Gather { .. } => "GATHER",
HybridOp::TopK { .. } => "TOPK",
HybridOp::GatherElements { .. } => "GATHERELEMENTS",
HybridOp::OneHot { .. } => "ONEHOT",
};
name.into()
}
@@ -303,6 +312,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
HybridOp::TopK { dim, k } => {
layouts::topk_axes(config, region, values[..].try_into()?, *k, *dim)?
}
HybridOp::OneHot { dim, num_classes } => {
layouts::one_hot_axis(config, region, values[..].try_into()?, *num_classes, *dim)?
}
}))
}

@@ -320,6 +332,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
| HybridOp::Less { .. }
| HybridOp::LessEqual { .. }
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Softmax { .. } => 2 * in_scales[0],
_ => in_scales[0],
@@ -359,6 +372,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
| HybridOp::Less { .. }
| HybridOp::Equals
| HybridOp::Gather { .. }
| HybridOp::OneHot { .. }
| HybridOp::TopK { .. }
| HybridOp::GatherElements { .. } => {
vec![LookupOp::GreaterThan {
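Taken together, the hybrid.rs changes wire OneHot through every Op hook: a new enum variant, a forward pass that defers to tensor::ops::one_hot, a string name, a layout that calls layouts::one_hot_axis, an output scale of 0 (the outputs are exact 0/1 indicators, like the comparison ops), and the same GreaterThan lookup the other comparison-backed hybrid ops request. A minimal sketch of the forward semantics, using only the public tensor op from the ops.rs diff below (felt conversion and circuit plumbing omitted):

use ezkl::tensor::Tensor;
use ezkl::tensor::ops::one_hot;

fn main() {
    // two class indices; the one-hot axis is inserted at dim 1 with 3 classes -> dims [2, 3]
    let x = Tensor::<i128>::new(Some(&[2, 0]), &[2]).unwrap();
    let res = one_hot(&x, 3, 1).unwrap();
    assert_eq!(res, Tensor::<i128>::new(Some(&[0, 0, 1, 1, 0, 0]), &[2, 3]).unwrap());
}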
122 changes: 122 additions & 0 deletions src/circuit/ops/layouts.rs
@@ -620,6 +620,128 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
Ok(assigned_output)
}

fn one_hot<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
num_classes: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// assert the input is flat
assert_eq!(values[0].dims().len(), 1);
// assert it is a single element
assert_eq!(values[0].len(), 1);
let input = values[0].clone();
let is_assigned = !input.any_unknowns();

let output: ValTensor<F> = if is_assigned {
let int_evals = input.get_int_evals()?;
let res = tensor::ops::one_hot(&int_evals, num_classes, 1)?;
res.iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<_>>()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); num_classes]),
&[num_classes],
)?
}
.into();

let assigned_input = region.assign(&config.inputs[0], &input)?;

// now assert all elems are 0 or 1
let assigned_output = region.assign(&config.inputs[1], &output)?;
for i in 0..assigned_output.len() {
let (x, y) = config.output.cartesian_coord(region.offset() + i);
let selector = config.selectors.get(&(BaseOp::IsBoolean, x));
region.enable(selector, y)?;
}
region.increment(std::cmp::max(assigned_output.len(), assigned_input.len()));

let sum = sum(config, region, &[assigned_output.clone()])?;
// constrain the sum of the one-hot cells to be exactly 1
let mut unit = Tensor::from(vec![F::from(1)].into_iter());
unit.set_visibility(crate::graph::Visibility::Public);
let unit = region.assign(&config.inputs[1], &unit.into())?;
region.assign(&config.output, &sum)?;

let (x, y) = config.output.cartesian_coord(region.offset());
let selector = config.selectors.get(&(BaseOp::Identity, x));
region.enable(selector, y)?;

region.increment(1);

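// constrain the cell of the one-hot vector at the input's index to be 1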
let gathered = gather(
config,
region,
&[assigned_output.clone(), assigned_input.clone()],
0,
)?;

region.assign(&config.inputs[1], &unit)?;
region.assign(&config.output, &gathered)?;

let (x, y) = config.output.cartesian_coord(region.offset());
let selector = config.selectors.get(&(BaseOp::Identity, x));
region.enable(selector, y)?;

region.increment(assigned_input.len());

Ok(assigned_output)
}
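// Illustration (not part of this commit): the scalar gadget above pins down a
// valid encoding with three checks. Replayed on plain integers, with a
// hypothetical helper name:
fn is_valid_one_hot(index: usize, encoding: &[i128]) -> bool {
    // 1. booleanity: every cell is 0 or 1 (BaseOp::IsBoolean)
    let boolean = encoding.iter().all(|&b| b == 0 || b == 1);
    // 2. the cells sum to exactly 1 (sum, then Identity against the public unit)
    let sums_to_one = encoding.iter().sum::<i128>() == 1;
    // 3. the cell at the claimed index is 1 (gather, then Identity against the unit)
    let hot_at_index = encoding.get(index) == Some(&1);
    boolean && sums_to_one && hot_at_index
}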

/// One-hot accumulated layout
pub fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
num_classes: usize,
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let input = values[0].clone();
let input_inner = input.get_inner_tensor()?;

let mut output_dims = values[0].dims().to_vec();
output_dims.insert(dim, num_classes);

let op_tensors = input_inner.enum_map(|_: usize, inp| {
let tensor = Tensor::new(Some(&[inp.clone()]), &[1]).unwrap();
let res = one_hot(config, region, &[tensor.into()], num_classes).map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;

Ok::<_, halo2_proofs::plonk::Error>(res)
})?;

// Allocate memory for the output tensor
let cartesian_coord = output_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();

let mut output = Tensor::<ValType<F>>::new(None, &output_dims)?;

output = output.enum_map(|i, _| {
let coord = cartesian_coord[i].clone();
let mut op_idx = coord.clone();
let coord_at_dims = vec![coord[dim]];
op_idx.remove(dim);

let op_tensor = op_tensors.get(&op_idx).get_inner_tensor().map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;

let one_hot_val = op_tensor.get(&coord_at_dims).clone();

Ok::<_, halo2_proofs::plonk::Error>(one_hot_val)
})?;

Ok(output.into())
}
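// Illustration (not part of this commit): the coordinate bookkeeping above on
// plain data, with a hypothetical helper name. Each output coordinate splits
// into a position along the inserted one-hot axis and the coordinate of the
// source element whose scalar gadget result it reads.
fn split_coord(coord: &[usize], dim: usize) -> (Vec<usize>, usize) {
    let class = coord[dim]; // index into the one-hot vector
    let mut src = coord.to_vec();
    src.remove(dim); // coordinate of the source element in the input
    (src, class)
}
// e.g. with output dims [2, 2, 5] and dim = 2: split_coord(&[1, 0, 3], 2) == (vec![1, 0], 3)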

/// Gather accumulated layout
pub fn gather<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
12 changes: 11 additions & 1 deletion src/graph/utilities.rs
@@ -22,7 +22,7 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(not(target_arch = "wasm32"))]
use tract_onnx::tract_core::ops::{
array::{Gather, GatherElements, Slice, Topk},
array::{Gather, GatherElements, OneHot, Slice, Topk},
change_axes::AxisOp,
cnn::DeconvUnary,
einsum::EinSum,
@@ -245,6 +245,16 @@ pub fn new_op_from_onnx(

SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::TopK { dim: axis, k })
}
"Onehot" => {
let op = load_op::<OneHot>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
let num_classes = op.dim;
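// tract's OneHot stores the insertion axis in `axis` and the number of
// classes in `dim`. Any custom on/off values on the ONNX node are not read
// here; the circuit always emits 0/1.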

SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::OneHot {
dim: axis,
num_classes,
})
}
"GatherElements" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(
46 changes: 46 additions & 0 deletions src/tensor/ops.rs
@@ -1945,6 +1945,52 @@ pub fn intercalate_values<T: TensorType>(
Ok(output)
}

/// One-hot encodes a tensor along a given axis.
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::one_hot;
/// let tensor = Tensor::<i128>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
/// let result = one_hot(&tensor, 5, 2).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 0, 0, 0,
/// 0, 0, 1, 0, 0,
/// 0, 0, 0, 1, 0,
/// 0, 0, 0, 0, 1]), &[2, 2, 5]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn one_hot(
tensor: &Tensor<i128>,
num_classes: usize,
axis: usize,
) -> Result<Tensor<i128>, TensorError> {
let mut output_dims = tensor.dims().to_vec();
output_dims.insert(axis, num_classes);

let mut output: Tensor<i128> = Tensor::new(None, &output_dims)?;

let cartesian_coord = output
.dims()
.iter()
.map(|d| (0..*d))
.multi_cartesian_product()
.collect::<Vec<_>>();

output.iter_mut().enumerate().for_each(|(i, o)| {
let coord = &cartesian_coord[i];
let coord_axis = coord[axis];

let mut coord_without_axis = coord.clone();
coord_without_axis.remove(axis);

if coord_axis == tensor.get(&coord_without_axis) as usize {
*o = 1;
} else {
*o = 0;
}
});
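// NB: an index outside [0, num_classes) matches no position along the
// inserted axis, so its one-hot row is all zeros rather than an error; small
// negative values wrap under the `as usize` cast and likewise yield all zeros.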

Ok(output)
}

/// Performs a 2D deconvolution on the given input tensor.
/// # Examples
/// ```