chore: update tract (#766)
alexander-camuto authored Apr 4, 2024
1 parent 5389012 commit 316a9a3
Showing 8 changed files with 249 additions and 152 deletions.
258 changes: 133 additions & 125 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -80,7 +80,7 @@ pyo3-asyncio = { version = "0.20.0", features = [
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "681a096f02c9d7d363102d9fb0e446d1710ac2c8", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }


15 changes: 15 additions & 0 deletions src/circuit/ops/layouts.rs
@@ -1975,6 +1975,21 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
Ok(output)
}

pub(crate) fn mean_of_squares_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
axes: &[usize],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let squared = pow(config, region, values, 2)?;
let sum_squared = sum_axes(config, region, &[squared], axes)?;

// number of input elements folded into each output element
let divisor: usize = values[0].len() / sum_squared.len();

let mean_squared = div(config, region, &[sum_squared], F::from(divisor as u64))?;
Ok(mean_squared)
}

/// expand the tensor to the given shape
pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
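The new `mean_of_squares_axes` layout composes three existing primitives: square the input, sum over the reduced axes, then divide each group by the number of elements folded into it. A minimal plain-Rust sketch of that arithmetic (illustrative only — it reuses the values from the ops.rs doctest further down rather than ezkl's ValTensor/RegionCtx machinery):

```rust
// Square, sum per group, divide by the group size — the same three steps
// the circuit layout performs with pow / sum_axes / div.
fn mean_of_squares_rows(data: &[Vec<i128>]) -> Vec<i128> {
    data.iter()
        .map(|row| {
            let sum_sq: i128 = row.iter().map(|x| x * x).sum();
            // divisor = number of inputs folded into each output
            (sum_sq as f64 / row.len() as f64).round() as i128
        })
        .collect()
}

fn main() {
    let x = vec![vec![2, 15, 2], vec![1, 1, 0]];
    // (4 + 225 + 4) / 3 ≈ 77.7 → 78; (1 + 1 + 0) / 3 ≈ 0.7 → 1
    assert_eq!(mean_of_squares_rows(&x), vec![78, 1]);
}
```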
39 changes: 34 additions & 5 deletions src/circuit/ops/poly.rs
@@ -1,6 +1,6 @@
use crate::{
circuit::layouts,
fieldutils::felt_to_i128,
fieldutils::{felt_to_i128, i128_to_felt},
tensor::{self, Tensor, TensorError},
};

@@ -62,6 +62,9 @@ pub enum PolyOp {
Sum {
axes: Vec<usize>,
},
MeanOfSquares {
axes: Vec<usize>,
},
Prod {
axes: Vec<usize>,
len_prod: usize,
@@ -105,10 +108,28 @@ impl<

fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::GatherElements { dim, constant_idx } => format!(
"GATHERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::GatherND {
batch_dims,
indices,
} => format!(
"GATHERND (batch_dims={}, constant_idx{})",
batch_dims,
indices.is_some()
),
PolyOp::MeanOfSquares { axes } => format!("MEANOFSQUARES (axes={:?})", axes),
PolyOp::ScatterElements { dim, constant_idx } => format!(
"SCATTERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::ScatterND { constant_idx } => {
format!("SCATTERND (constant_idx={})", constant_idx.is_some())
}
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -146,6 +167,10 @@ impl<
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let mut inputs = inputs.to_vec();
let res = match &self {
PolyOp::MeanOfSquares { axes } => {
let x = inputs[0].map(|x| felt_to_i128(x));
Ok(tensor::ops::nonlinearities::mean_of_squares_axes(&x, axes).map(i128_to_felt))
}
PolyOp::MultiBroadcastTo { shape } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch(
@@ -292,6 +317,9 @@ impl<
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
}
PolyOp::MeanOfSquares { axes } => {
layouts::mean_of_squares_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Xor => layouts::xor(config, region, values[..].try_into()?)?,
PolyOp::Or => layouts::or(config, region, values[..].try_into()?)?,
PolyOp::And => layouts::and(config, region, values[..].try_into()?)?,
@@ -404,6 +432,7 @@ impl<

fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
let scale = match self {
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
PolyOp::Iff => in_scales[1],
PolyOp::Einsum { .. } => {
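The new `out_scale` arm returns `2 * in_scales[0]` for `MeanOfSquares`. That follows from fixed-point encoding: squaring a value stored at scale s multiplies two factors of 2^s, while the subsequent sum and integer division leave the scale unchanged. A small sketch of the bookkeeping (assuming the usual power-of-two scale convention; not ezkl's actual quantization code):

```rust
fn main() {
    let scale = 4u32; // input scale s
    let r = 1.5f64;
    let q = (r * (1u64 << scale) as f64).round() as i64; // 24 = 1.5 * 2^4
    // (r * 2^s)^2 = r^2 * 2^(2s): the square lives at scale 2s, and
    // summing then dividing by an integer count keeps it there.
    let q2 = q * q; // 576
    let decoded = q2 as f64 / (1u64 << (2 * scale)) as f64;
    assert_eq!(decoded, 2.25); // 1.5^2 recovered at scale 2s = 8
}
```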
32 changes: 16 additions & 16 deletions src/graph/model.rs
@@ -1200,6 +1200,20 @@ impl Model {
.collect();

for (idx, node) in self.graph.nodes.iter() {
debug!("laying out {}: {}", idx, node.as_str(),);
// Then report the number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("input indices: {:?}", node.inputs());
debug!("output scales: {:?}", node.out_scales());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);

let mut values: Vec<ValTensor<Fp>> = if !node.is_input() {
node.inputs()
.iter()
@@ -1211,25 +1225,11 @@
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};

debug!("laying out {}: {}", idx, node.as_str(),);
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("dims: {:?}", node.out_dims());
debug!("output dims: {:?}", node.out_dims());
debug!(
"input_dims {:?}",
"input dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
debug!("output scales: {:?}", node.out_scales());
debug!("input indices: {:?}", node.inputs());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);

match &node {
NodeType::Node(n) => {
21 changes: 17 additions & 4 deletions src/graph/utilities.rs
@@ -509,7 +509,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -545,7 +545,7 @@
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
@@ -582,7 +582,7 @@
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -734,6 +734,19 @@

SupportedOp::Linear(PolyOp::Sum { axes })
}
"Reduce<MeanOfSquares>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"mean of squares".to_string(),
)));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();

SupportedOp::Linear(PolyOp::MeanOfSquares { axes })
}

"Max" => {
// Extract the max value
// first find the input that is a constant
@@ -1165,7 +1178,7 @@
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar pow")
}
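Several hunks above replace `deleted_indices.push(inputs.len() - 1)` with `deleted_indices.push(1)`. The folded constant is the node's second input (the indices tensor), so position 1 is the one to mark; `inputs.len() - 1` only coincided with it for two-input ops like GatherElements and pointed at the wrong input for three-input ops like ScatterElements. A toy illustration of the off-by-position issue (hypothetical input layout, not ezkl code):

```rust
fn main() {
    // ONNX ScatterElements takes three inputs: (data, indices, updates).
    let inputs = ["data", "indices", "updates"];
    let old_mark = inputs.len() - 1; // 2 → would mark "updates" for deletion
    let new_mark = 1; // the constant indices tensor, as intended
    assert_eq!(inputs[old_mark], "updates");
    assert_eq!(inputs[new_mark], "indices");
}
```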
8 changes: 7 additions & 1 deletion src/tensor/mod.rs
@@ -934,6 +934,7 @@ impl<T: Clone + TensorType> Tensor<T> {
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<Self, TensorError> {
assert!(source < self.dims.len());
assert!(destination < self.dims.len());

let mut new_dims = self.dims.clone();
new_dims.remove(source);
new_dims.insert(destination, self.dims[source]);
@@ -965,6 +966,8 @@
old_coord[source - 1] = *c;
} else if (i < source && source < destination)
|| (i < destination && source > destination)
|| (i > source && source > destination)
|| (i > destination && source < destination)
{
old_coord[i] = *c;
} else if i > source && source < destination {
@@ -977,7 +980,10 @@
));
}
}
output.set(&coord, self.get(&old_coord));

let value = self.get(&old_coord);

output.set(&coord, value);
}

Ok(output)
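The extra branches in the coordinate remapping appear to cover indices that lie strictly beyond both `source` and `destination`, which previously fell through to the error arm. For reference, the mapping `move_axis` has to realize is numpy-style moveaxis: permute the dims and send each output coordinate back to exactly one input coordinate. A standalone sketch over flat row-major storage (illustrative, not the Tensor API):

```rust
fn main() {
    let (d0, d1, d2) = (2usize, 3usize, 4usize);
    let data: Vec<usize> = (0..d0 * d1 * d2).collect();
    // Move axis 0 to position 2: dims [2, 3, 4] -> [3, 4, 2], and output
    // coordinate (j, k, i) reads from input coordinate (i, j, k).
    let mut moved = vec![0usize; d0 * d1 * d2];
    for i in 0..d0 {
        for j in 0..d1 {
            for k in 0..d2 {
                let src = i * d1 * d2 + j * d2 + k; // (i, j, k) in [2, 3, 4]
                let dst = j * d2 * d0 + k * d0 + i; // (j, k, i) in [3, 4, 2]
                moved[dst] = data[src];
            }
        }
    }
    assert_eq!(moved[0], data[0]); // (0,0,0) maps to itself
    assert_eq!(moved[1], d1 * d2); // new (0,0,1) = old (1,0,0) = 12
}
```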
26 changes: 26 additions & 0 deletions src/tensor/ops.rs
@@ -4404,6 +4404,32 @@ pub mod nonlinearities {
let sum = sum(a).unwrap();
const_div(&sum, (scale * a.len()) as f64)
}

/// Mean of squares along the given axes
/// # Arguments
/// * `a` - Tensor
/// * `axes` - &[usize] of axes to reduce over
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::tensor::ops::nonlinearities::mean_of_squares_axes;
/// let x = Tensor::<i128>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = mean_of_squares_axes(&x, &[1]);
/// let expected = Tensor::<i128>::new(
/// Some(&[78, 1]),
/// &[2, 1],
/// ).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn mean_of_squares_axes(a: &Tensor<i128>, axes: &[usize]) -> Tensor<i128> {
let square = a.map(|a_i| a_i * a_i);
let sum = sum_axes(&square, axes).unwrap();
let denominator = a.len() / sum.len();
const_div(&sum, denominator as f64)
}
}

/// Ops that return the transcript i.e intermediate calcs of an op
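One follow-up note on the doctest: the divisor is derived as `a.len() / sum.len()`, so reducing over every axis averages all elements at once. A hedged extension of the example (it assumes `sum_axes` keeps reduced dims at size 1, as the `[2, 1]` output above suggests, and that `const_div` rounds to nearest, as 233 / 3 → 78 implies):

```rust
use ezkl::tensor::Tensor;
use ezkl::tensor::ops::nonlinearities::mean_of_squares_axes;

fn main() {
    let x = Tensor::<i128>::new(Some(&[2, 15, 2, 1, 1, 0]), &[2, 3]).unwrap();
    // Reduce over both axes: divisor = 6 / 1 = 6.
    let result = mean_of_squares_axes(&x, &[0, 1]);
    // (4 + 225 + 4 + 1 + 1 + 0) / 6 = 235 / 6 ≈ 39.2 → 39
    let expected = Tensor::<i128>::new(Some(&[39]), &[1, 1]).unwrap();
    assert_eq!(result, expected);
}
```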
