fix: slice for empty dim tensors (#520)
alexander-camuto authored Oct 4, 2023
1 parent 7ee8dfd commit c3aeae4
Showing 9 changed files with 74 additions and 29 deletions.
41 changes: 41 additions & 0 deletions examples/onnx/1l_linear/gen.py
@@ -0,0 +1,41 @@
import random
import math
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
import json


model = nn.Linear(1, 1)
x = torch.randn(1, 1)

print(x)

# Flips the neural net into inference mode
model.eval()
model.to('cpu')

# Export the model
torch.onnx.export(model, # model being run
# model input (or a tuple for multiple inputs)
x,
# where to save the model (can be a file or file-like object)
"network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})

data_array = ((x).detach().numpy()).reshape([-1]).tolist()

data_json = dict(input_data=[data_array])

print(data_json)

# Serialize data into file:
json.dump(data_json, open("input.json", 'w'))
1 change: 1 addition & 0 deletions examples/onnx/1l_linear/input.json
@@ -0,0 +1 @@
{"input_data": [[-0.13821937143802643]]}
Binary file added examples/onnx/1l_linear/network.onnx
16 changes: 3 additions & 13 deletions src/circuit/ops/layouts.rs
@@ -163,17 +163,6 @@ pub fn einsum<F: PrimeField + TensorType + PartialOrd>(
    let output_eq = equation.next().unwrap();
    let inputs_eq = inputs_eq.split(',').collect::<Vec<_>>();

-    for (i, input) in inputs.iter_mut().enumerate() {
-        if input.dims().len() != inputs_eq[i].len()
-            && input.dims().len() == 1
-            && inputs_eq[i].len() == 2
-        {
-            input.reshape(&[1, input.dims()[0]])?;
-        } else if input.dims().len() != inputs_eq[i].len() {
-            return Err(Box::new(TensorError::DimMismatch("einsum".to_string())));
-        }
-    }
-
    // Check that the number of inputs matches the number of inputs in the equation
    if inputs.len() != inputs_eq.len() {
        return Err(Box::new(TensorError::DimMismatch("einsum".to_string())));
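Note: the block removed above both promoted rank-1 inputs to `[1, n]` and rejected any other rank mismatch. A rank-0 singleton tensor (empty dims, one element — made legal elsewhere in this commit) can never match a one-character equation term, so it would always have fallen into the `DimMismatch` arm. A hypothetical sketch of that misfire, not ezkl API:

```rust
// Hypothetical reconstruction of the removed check's behavior: a rank-0
// operand (dims.len() == 0) never equals a term length of 1 and is not
// promotable either, so the old code returned DimMismatch for any scalar.
fn old_check_rejects(rank: usize, term_len: usize) -> bool {
    let promotable = rank != term_len && rank == 1 && term_len == 2;
    !promotable && rank != term_len
}

fn main() {
    assert!(old_check_rejects(0, 1)); // rank-0 singleton vs term "i": spurious error
    assert!(!old_check_rejects(1, 2)); // rank-1 vs "ij": was silently reshaped
    assert!(!old_check_rejects(2, 2)); // matching ranks always passed
}
```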
@@ -663,7 +652,7 @@ pub fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
    let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| -> ValTensor<F> {
        let inp = input_inner[i].clone();
        let tensor = Tensor::new(Some(&[inp.clone()]), &[1]).unwrap();

        one_hot(config, region, &[tensor.into()], num_classes).unwrap()
    };

@@ -732,7 +721,8 @@ pub fn gather<F: PrimeField + TensorType + PartialOrd>(
    // Calculate the output tensor size
    let input_dims = input.dims();
    let mut output_size = input_dims.to_vec();
-    if index.dims().is_empty() {
+    if index.is_singleton() {
+        assert_eq!(input_dims[dim], 1);
        output_size.remove(dim);
        input.reshape(&output_size)?;
        return Ok(input);
5 changes: 2 additions & 3 deletions src/graph/utilities.rs
@@ -102,7 +102,7 @@ fn extract_tensor_value(
    let dims = input.shape().to_vec();

    let mut const_value: Tensor<f32>;
-    if dims.is_empty() {
+    if dims.is_empty() && input.len() == 0 {
        const_value = Tensor::<f32>::new(None, &dims)?;
        return Ok(const_value);
    }
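The tightened guard matters because a rank-0 (scalar) ONNX constant also reports an empty shape while still carrying one value; only a genuinely value-less constant should take the early return. A minimal sketch of the predicate, with hypothetical stand-in names:

```rust
// `dims` and `len` stand in for input.shape() and input.len() above.
fn takes_empty_path(dims: &[usize], len: usize) -> bool {
    // old: dims.is_empty()             -> rank-0 scalars wrongly short-circuited
    // new: dims.is_empty() && len == 0 -> only truly empty constants do
    dims.is_empty() && len == 0
}

fn main() {
    assert!(takes_empty_path(&[], 0)); // empty constant: early return
    assert!(!takes_empty_path(&[], 1)); // rank-0 scalar: now quantized normally
}
```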
@@ -248,7 +248,7 @@ pub fn new_op_from_onnx(
            op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather {
                dim: axis,
                constant_idx: Some(c.raw_values.map(|x| x as usize)),
-            })
+            });
        }
        // }

@@ -1170,7 +1170,6 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
        )?))
    })?;

-    value.reshape(const_value.dims());
    value.set_scale(scale);
    value.set_visibility(visibility);
    Ok(value)
23 changes: 13 additions & 10 deletions src/tensor/mod.rs
@@ -450,6 +450,8 @@ impl<T: Clone + TensorType> Tensor<T> {
    pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
        let total_dims: usize = if !dims.is_empty() {
            dims.iter().product()
+        } else if let Some(_) = values {
+            1
        } else {
            0
        };
@@ -496,15 +498,16 @@

    /// Returns the number of elements in the tensor.
    pub fn len(&self) -> usize {
-        if !self.dims().is_empty() && (self.dims() != [0]) {
-            self.dims().iter().product::<usize>()
-        } else {
-            0
-        }
+        self.dims().iter().product::<usize>()
    }
    /// Checks if the number of elements in tensor is 0.
    pub fn is_empty(&self) -> bool {
-        self.dims().iter().product::<usize>() == 0
+        self.inner.len() == 0
    }

+    /// Checks if the number of elements in tensor is 1 but with an empty dimension (this is for onnx compatibility).
+    pub fn is_singleton(&self) -> bool {
+        self.dims().is_empty() && self.len() == 1
+    }

    /// Set one single value on the tensor.
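Net effect of the three predicates: `len` is always the product of dims (an empty dims list yields the empty product, 1), `is_empty` now consults the backing storage directly, and `is_singleton` picks out exactly the rank-0, one-element case. A hedged illustration — the `ezkl::tensor::Tensor` path is an assumption:

```rust
use ezkl::tensor::Tensor; // crate path assumed

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // empty: zero elements, nothing stored
    let empty = Tensor::<f32>::new(None, &[0])?;
    assert!(empty.is_empty() && !empty.is_singleton());

    // singleton: one element but rank-0 (empty dims) -- an ONNX scalar
    let scalar = Tensor::<f32>::new(Some(&[7.0]), &[])?;
    assert!(scalar.is_singleton() && scalar.len() == 1);

    // ordinary one-element tensor: dims [1], so not a singleton
    let one = Tensor::<f32>::new(Some(&[7.0]), &[1])?;
    assert!(!one.is_singleton() && one.len() == 1);
    Ok(())
}
```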
@@ -599,11 +602,11 @@ impl<T: Clone + TensorType> Tensor<T> {
    where
        T: Send + Sync,
    {
+        if indices.is_empty() {
+            return Ok(self.clone());
+        }
        if self.dims.len() < indices.len() {
            return Err(TensorError::DimError);
-        } else if indices.is_empty() {
-            // else if slice is empty, return empty tensor
-            return Ok(Tensor::new(None, &[]).unwrap());
        } else if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims {
            // else if slice is the same as dims, return self
            return Ok(self.clone());
@@ -768,7 +771,7 @@ impl<T: Clone + TensorType> Tensor<T> {
        // in onnx parlance this corresponds to converting a tensor to a single element
        if new_dims.is_empty() {
            assert!(self.len() == 1 || self.is_empty());
-            self.flatten();
+            self.dims = vec![];
        } else {
            let product = if new_dims != [0] {
                new_dims.iter().product::<usize>()
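Two behavioral changes fall out of the hunks above: an empty `indices` list passed to `get_slice` now means "take everything" (it previously produced an empty tensor), and `reshape(&[])` squeezes a one-element tensor down to a rank-0 singleton while keeping its value. A sketch under the same path assumption as above:

```rust
use ezkl::tensor::Tensor; // crate path assumed; get_slice/reshape as in this diff

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let t = Tensor::<f32>::new(Some(&[1.0, 2.0, 3.0, 4.0]), &[2, 2])?;

    // empty slice list: now a full copy rather than an empty tensor
    let all = t.get_slice(&[])?;
    assert_eq!(all.dims(), &[2, 2]);

    // squeeze [1] -> rank-0; the assert inside reshape allows len 1 or 0
    let mut s = Tensor::<f32>::new(Some(&[9.0]), &[1])?;
    s.reshape(&[]);
    assert!(s.is_singleton());
    Ok(())
}
```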
4 changes: 3 additions & 1 deletion src/tensor/ops.rs
@@ -740,6 +740,7 @@ pub fn einsum<
                slice.push(0..inputs[idx].dims()[i]);
            }
        }
+
        // Get the slice of the input tensor
        inputs[idx].get_slice(&slice).unwrap()
    })
@@ -1152,7 +1153,8 @@ pub fn gather<T: TensorType + Send + Sync>(
    // Calculate the output tensor size
    let mut output_size = input.dims().to_vec();
    // Reshape the output tensor
-    if index.is_empty() {
+    if index.is_singleton() {
+        assert_eq!(output_size[dim], 1);
        output_size.remove(dim);
        let mut input = input.clone();
        input.reshape(&output_size);
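This mirrors ONNX `Gather` with a scalar index: the indexed axis must have extent 1 and is squeezed out of the result instead of being kept at size 1. A shape-only sketch of the squeeze (hypothetical helper, not ezkl API):

```rust
// Hypothetical shape computation matching the singleton branch above.
fn gathered_shape(input_dims: &[usize], dim: usize, index_is_singleton: bool) -> Vec<usize> {
    let mut out = input_dims.to_vec();
    if index_is_singleton {
        assert_eq!(out[dim], 1, "singleton gather requires extent 1 on `dim`");
        out.remove(dim); // e.g. [3, 1, 4] at dim = 1 -> [3, 4]
    }
    out
}

fn main() {
    assert_eq!(gathered_shape(&[3, 1, 4], 1, true), vec![3, 4]);
    assert_eq!(gathered_shape(&[3, 1, 4], 1, false), vec![3, 1, 4]);
}
```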
8 changes: 8 additions & 0 deletions src/tensor/val.rs
@@ -382,6 +382,14 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
        Ok(res)
    }

+    /// Calls is_singleton on the inner tensor.
+    pub fn is_singleton(&self) -> bool {
+        match self {
+            ValTensor::Value { inner, .. } => inner.is_singleton(),
+            ValTensor::Instance { .. } => false,
+        }
+    }
+
    /// Calls `int_evals` on the inner tensor.
    pub fn get_int_evals(&self) -> Result<Tensor<i128>, Box<dyn Error>> {
        // finally convert to vector of integers
5 changes: 3 additions & 2 deletions tests/integration_tests.rs
@@ -188,7 +188,7 @@ mod native_tests {
        "1l_prelu",
    ];

-    const TESTS: [&str; 59] = [
+    const TESTS: [&str; 60] = [
        "1l_mlp",
        "1l_slice",
        "1l_concat",
@@ -251,6 +251,7 @@
        "xgboost_reg",
        "1l_powf",
        "scatter_elements",
+        "1l_linear", //60
    ];

    const WASM_TESTS: [&str; 48] = [
@@ -480,7 +481,7 @@ mod native_tests {



-    seq!(N in 0..=58 {
+    seq!(N in 0..=59 {
        #(#[test_case(TESTS[N])])*
        fn model_serialization_(test: &str) {
            let test_dir = TempDir::new(test).unwrap();
