diff --git a/examples/onnx/1l_linear/gen.py b/examples/onnx/1l_linear/gen.py new file mode 100644 index 000000000..5abf4ff28 --- /dev/null +++ b/examples/onnx/1l_linear/gen.py @@ -0,0 +1,41 @@ +import random +import math +import numpy as np + +import torch +from torch import nn +import torch.nn.functional as F +import json + + +model = nn.Linear(1, 1) +x = torch.randn(1, 1) + +print(x) + +# Flips the neural net into inference mode +model.eval() +model.to('cpu') + +# Export the model +torch.onnx.export(model, # model being run + # model input (or a tuple for multiple inputs) + x, + # where to save the model (can be a file or file-like object) + "network.onnx", + export_params=True, # store the trained parameter weights inside the model file + opset_version=10, # the ONNX version to export the model to + do_constant_folding=True, # whether to execute constant folding for optimization + input_names=['input'], # the model's input names + output_names=['output'], # the model's output names + dynamic_axes={'input': {0: 'batch_size'}, # variable length axes + 'output': {0: 'batch_size'}}) + +data_array = ((x).detach().numpy()).reshape([-1]).tolist() + +data_json = dict(input_data=[data_array]) + +print(data_json) + +# Serialize data into file: +json.dump(data_json, open("input.json", 'w')) diff --git a/examples/onnx/1l_linear/input.json b/examples/onnx/1l_linear/input.json new file mode 100644 index 000000000..9d34f9ddc --- /dev/null +++ b/examples/onnx/1l_linear/input.json @@ -0,0 +1 @@ +{"input_data": [[-0.13821937143802643]]} \ No newline at end of file diff --git a/examples/onnx/1l_linear/network.onnx b/examples/onnx/1l_linear/network.onnx new file mode 100644 index 000000000..c3fd5f90e Binary files /dev/null and b/examples/onnx/1l_linear/network.onnx differ diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 0f9e79c98..efbba9e6f 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -163,17 +163,6 @@ pub fn einsum( let output_eq = equation.next().unwrap(); let inputs_eq = inputs_eq.split(',').collect::>(); - for (i, input) in inputs.iter_mut().enumerate() { - if input.dims().len() != inputs_eq[i].len() - && input.dims().len() == 1 - && inputs_eq[i].len() == 2 - { - input.reshape(&[1, input.dims()[0]])?; - } else if input.dims().len() != inputs_eq[i].len() { - return Err(Box::new(TensorError::DimMismatch("einsum".to_string()))); - } - } - // Check that the number of inputs matches the number of inputs in the equation if inputs.len() != inputs_eq.len() { return Err(Box::new(TensorError::DimMismatch("einsum".to_string()))); @@ -663,7 +652,7 @@ pub fn one_hot_axis( let inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| -> ValTensor { let inp = input_inner[i].clone(); let tensor = Tensor::new(Some(&[inp.clone()]), &[1]).unwrap(); - + one_hot(config, region, &[tensor.into()], num_classes).unwrap() }; @@ -732,7 +721,8 @@ pub fn gather( // Calculate the output tensor size let input_dims = input.dims(); let mut output_size = input_dims.to_vec(); - if index.dims().is_empty() { + if index.is_singleton() { + assert_eq!(input_dims[dim], 1); output_size.remove(dim); input.reshape(&output_size)?; return Ok(input); diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 82d05c050..882d75fd7 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -102,7 +102,7 @@ fn extract_tensor_value( let dims = input.shape().to_vec(); let mut const_value: Tensor; - if dims.is_empty() { + if dims.is_empty() && input.len() == 0 { const_value = Tensor::::new(None, &dims)?; return Ok(const_value); } @@ -248,7 +248,7 @@ pub fn new_op_from_onnx( op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather { dim: axis, constant_idx: Some(c.raw_values.map(|x| x as usize)), - }) + }); } // } @@ -1170,7 +1170,6 @@ pub fn quantize_tensor( )?)) })?; - value.reshape(const_value.dims()); value.set_scale(scale); value.set_visibility(visibility); Ok(value) diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 38eae6d9d..267545507 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -450,6 +450,8 @@ impl Tensor { pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result { let total_dims: usize = if !dims.is_empty() { dims.iter().product() + } else if let Some(_) = values { + 1 } else { 0 }; @@ -496,15 +498,16 @@ impl Tensor { /// Returns the number of elements in the tensor. pub fn len(&self) -> usize { - if !self.dims().is_empty() && (self.dims() != [0]) { - self.dims().iter().product::() - } else { - 0 - } + self.dims().iter().product::() } /// Checks if the number of elements in tensor is 0. pub fn is_empty(&self) -> bool { - self.dims().iter().product::() == 0 + self.inner.len() == 0 + } + + /// Checks if the number of elements in tensor is 1 but with an empty dimension (this is for onnx compatibility). + pub fn is_singleton(&self) -> bool { + self.dims().is_empty() && self.len() == 1 } /// Set one single value on the tensor. @@ -599,11 +602,11 @@ impl Tensor { where T: Send + Sync, { + if indices.is_empty() { + return Ok(self.clone()); + } if self.dims.len() < indices.len() { return Err(TensorError::DimError); - } else if indices.is_empty() { - // else if slice is empty, return empty tensor - return Ok(Tensor::new(None, &[]).unwrap()); } else if indices.iter().map(|x| x.end - x.start).collect::>() == self.dims { // else if slice is the same as dims, return self return Ok(self.clone()); @@ -768,7 +771,7 @@ impl Tensor { // in onnx parlance this corresponds to converting a tensor to a single element if new_dims.is_empty() { assert!(self.len() == 1 || self.is_empty()); - self.flatten(); + self.dims = vec![]; } else { let product = if new_dims != [0] { new_dims.iter().product::() diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index a7e265e3e..487123121 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -740,6 +740,7 @@ pub fn einsum< slice.push(0..inputs[idx].dims()[i]); } } + // Get the slice of the input tensor inputs[idx].get_slice(&slice).unwrap() }) @@ -1152,7 +1153,8 @@ pub fn gather( // Calculate the output tensor size let mut output_size = input.dims().to_vec(); // Reshape the output tensor - if index.is_empty() { + if index.is_singleton() { + assert_eq!(output_size[dim], 1); output_size.remove(dim); let mut input = input.clone(); input.reshape(&output_size); diff --git a/src/tensor/val.rs b/src/tensor/val.rs index 52442fbd1..d356420c5 100644 --- a/src/tensor/val.rs +++ b/src/tensor/val.rs @@ -382,6 +382,14 @@ impl ValTensor { Ok(res) } + /// Calls is_singleton on the inner tensor. + pub fn is_singleton(&self) -> bool { + match self { + ValTensor::Value { inner, .. } => inner.is_singleton(), + ValTensor::Instance { .. } => false, + } + } + /// Calls `int_evals` on the inner tensor. pub fn get_int_evals(&self) -> Result, Box> { // finally convert to vector of integers diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index e7131da6b..ad139d936 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -188,7 +188,7 @@ mod native_tests { "1l_prelu", ]; - const TESTS: [&str; 59] = [ + const TESTS: [&str; 60] = [ "1l_mlp", "1l_slice", "1l_concat", @@ -251,6 +251,7 @@ mod native_tests { "xgboost_reg", "1l_powf", "scatter_elements", + "1l_linear", //60 ]; const WASM_TESTS: [&str; 48] = [ @@ -480,7 +481,7 @@ mod native_tests { - seq!(N in 0..=58 { + seq!(N in 0..=59 { #(#[test_case(TESTS[N])])* fn model_serialization_(test: &str) { let test_dir = TempDir::new(test).unwrap();