feat: downsample op (#337)
alexander-camuto authored Jul 6, 2023
1 parent 5114f98 commit 4ffbd68
Showing 11 changed files with 175 additions and 15 deletions.
23 changes: 11 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -19,7 +19,7 @@ halo2curves = { git = 'https://github.com/privacy-scaling-explorations/halo2curv
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
plotters = { version = "0.3.0", default_features = false, optional = true }
-tract-onnx = { git = "https://github.com/sonos/tract/", rev= "0a661fe", default_features = false, optional = true }
+tract-onnx = { git = "https://github.com/sonos/tract/", branch= "fixes-1122", default_features = false, optional = true }
clap = { version = "4.3.3", features = ["derive"]}
serde = { version = "1.0.126", features = ["derive"], optional = true }
serde_json = { version = "1.0.97", default_features = false, features = ["float_roundtrip", "raw_value"], optional = true }
15 changes: 15 additions & 0 deletions examples/onnx/1l_downsample/gen.py
@@ -0,0 +1,15 @@
from torch import nn
from ezkl import export


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.layer = nn.Conv2d(3, 1, (1, 1), 2, 1)

    def forward(self, x):
        return self.layer(x)


circuit = Model()
export(circuit, input_shape=[3, 6, 6])
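Note the layer: with a 1x1 kernel and stride 2, the convolution does no spatial mixing, so its stride amounts to keeping every second element along each spatial axis. tract factors exactly this kind of strided access out into its `Downsample` op, which is what this commit adds support for. A minimal, self-contained sketch of the one-axis access pattern (illustrative only; `downsample_1d` is not code from this commit):

```rust
// Sketch of the one-axis indexing a downsample performs: keep the elements
// at indices modulo, modulo + stride, modulo + 2*stride, ...
// (stride must be > 0: step_by panics on a step of 0)
fn downsample_1d<T: Copy>(xs: &[T], stride: usize, modulo: usize) -> Vec<T> {
    xs.iter().skip(modulo).step_by(stride).copied().collect()
}

fn main() {
    let xs = [1, 2, 3, 4, 5, 6];
    assert_eq!(downsample_1d(&xs, 2, 0), vec![1, 3, 5]);
    assert_eq!(downsample_1d(&xs, 2, 1), vec![2, 4, 6]);
}
```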
1 change: 1 addition & 0 deletions examples/onnx/1l_downsample/input.json
@@ -0,0 +1 @@
{"input_shapes": [[3, 6, 6]], "input_data": [[0.0013342977035790682, 0.09996837377548218, 0.04596591740846634, 0.028909338638186455, 0.0023150264751166105, 0.031060589477419853, 0.04478610306978226, 0.06076720356941223, 0.045172106474637985, 0.05557824298739433, 0.07805796712636948, 0.008817249909043312, 0.06720414012670517, 0.08955056220293045, 0.09216128289699554, 0.045477788895368576, 0.01003159862011671, 0.058520179241895676, 0.0695599615573883, 0.042656876146793365, 0.08834701031446457, 0.04972696304321289, 0.008722865954041481, 0.08152175694704056, 0.021779973059892654, 0.020261067897081375, 0.08429203927516937, 0.033855486661195755, 0.0801347866654396, 0.04853340983390808, 0.07470901310443878, 0.09944545477628708, 0.012083947658538818, 0.05083233118057251, 0.016968900337815285, 0.09431719779968262, 0.026137595996260643, 0.032507576048374176, 0.06173516437411308, 0.020426113158464432, 0.06253501772880554, 0.026606155559420586, 0.0072502256371080875, 0.06107022985816002, 0.07124727219343185, 0.06443677097558975, 0.009884399361908436, 0.05601707100868225, 0.04880037531256676, 0.04238104820251465, 0.011621475219726562, 0.050199996680021286, 0.07783322036266327, 0.008418447338044643, 0.06102946028113365, 0.07127648591995239, 0.05345148593187332, 0.05489174649119377, 0.048850901424884796, 0.06494833528995514, 0.0527624785900116, 0.03499723598361015, 0.04975031688809395, 0.08880099654197693, 0.09861764311790466, 0.08910749852657318, 0.09743963927030563, 0.05271735414862633, 0.08343677967786789, 0.02000223472714424, 0.08513142168521881, 0.034427475184202194, 0.0873868465423584, 0.010047852993011475, 0.026587819680571556, 0.08355271071195602, 0.027053965255618095, 0.03804783895611763, 0.04176938161253929, 0.04468965157866478, 0.05959935113787651, 0.08032093197107315, 0.03739238530397415, 0.027487248182296753, 0.06620214879512787, 0.0759638100862503, 0.053834449499845505, 0.09121784567832947, 0.04995650798082352, 0.015401131473481655, 0.019204294309020042, 0.0033545016776770353, 0.05057207867503166, 0.008860093541443348, 0.0005154967657290399, 0.003039717674255371, 0.021668827161192894, 0.009054434485733509, 0.04053889587521553, 0.054946959018707275, 0.017397576943039894, 0.024417413398623466, 0.06150643154978752, 0.06648053973913193, 0.024091368541121483, 0.014438122510910034, 0.05136483907699585, 0.015915799885988235]], "output_data": [[0.26444995403289795, 0.26444995403289795, 0.26444995403289795, 0.26444995403289795, 0.26444995403289795, 0.26322141289711, 0.2664680778980255, 0.2690267860889435, 0.26444995403289795, 0.2638533413410187, 0.26202982664108276, 0.25793612003326416, 0.26444995403289795, 0.2585963308811188, 0.2597063481807709, 0.25475919246673584]]}
26 changes: 26 additions & 0 deletions examples/onnx/1l_downsample/network.onnx
@@ -0,0 +1,26 @@
[Binary ONNX protobuf, not rendered: a pytorch 1.13.1 export containing a single `/layer/Conv` node (attributes: dilations, group, kernel_shape, pads, strides), `layer.weight` and `layer.bias` initializers, and input/output tensors with a dynamic `batch_size` dimension.]
17 changes: 17 additions & 0 deletions src/circuit/ops/layouts.rs
@@ -1363,6 +1363,23 @@ pub fn identity<F: PrimeField + TensorType + PartialOrd>(
    Ok(output)
}

/// Downsample layout
pub fn downsample<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 1],
    axis: &usize,
    stride: &usize,
    modulo: &usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
    let input = region.assign(&config.inputs[0], &values[0])?;
    let processed_output =
        tensor::ops::downsample(&input.get_inner_tensor()?, *axis, *stride, *modulo)?;
    let output = region.assign(&config.output, &processed_output.into())?;
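    // advance the region offset past the rows consumed by the assigned input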
    region.increment(input.len());
    Ok(output)
}

/// Layout for range check.
pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
17 changes: 17 additions & 0 deletions src/circuit/ops/poly.rs
@@ -18,6 +18,11 @@ pub enum PolyOp<F: PrimeField + TensorType + PartialOrd> {
        padding: (usize, usize),
        stride: (usize, usize),
    },
    Downsample {
        axis: usize,
        stride: usize,
        modulo: usize,
    },
    DeConv {
        kernel: ValTensor<F>,
        bias: Option<ValTensor<F>>,
@@ -73,6 +78,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
    }
    fn as_string(&self) -> String {
        let name = match &self {
            PolyOp::Downsample { .. } => "DOWNSAMPLE",
            PolyOp::Resize { .. } => "RESIZE",
            PolyOp::Iff => "IFF",
            PolyOp::Einsum { .. } => "EINSUM",
@@ -101,6 +107,11 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
    fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
        let mut inputs = inputs.to_vec();
        let res = match &self {
            PolyOp::Downsample {
                axis,
                stride,
                modulo,
            } => tensor::ops::downsample(&inputs[0], *axis, *stride, *modulo),
            PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
            PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
            PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
@@ -214,6 +225,11 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
        let mut values = values.to_vec();

        Ok(Some(match self {
            PolyOp::Downsample {
                axis,
                stride,
                modulo,
            } => layouts::downsample(config, region, values[..].try_into()?, axis, stride, modulo)?,
            PolyOp::Resize { scale_factor } => {
                layouts::resize(config, region, values[..].try_into()?, scale_factor)?
            }
@@ -312,6 +328,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {

    fn out_scale(&self, in_scales: Vec<u32>, _g: u32) -> u32 {
        match self {
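            // downsampling only selects elements (no arithmetic), so the
            // fixed-point scale passes through unchanged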
            PolyOp::Downsample { .. } => in_scales[0],
            PolyOp::Resize { .. } => in_scales[0],
            PolyOp::Iff => in_scales[1],
            PolyOp::Einsum { .. } => {
18 changes: 18 additions & 0 deletions src/graph/utilities.rs
@@ -13,6 +13,7 @@ use tract_onnx::tract_core::ops::array::Gather;
use tract_onnx::tract_core::ops::array::Slice;
use tract_onnx::tract_core::ops::cnn::DeconvUnary;
use tract_onnx::tract_core::ops::einsum::EinSum;
use tract_onnx::tract_core::ops::Downsample;

use tract_onnx::tract_core::ops::element_wise::ElementWiseOp;

@@ -616,6 +617,23 @@ pub fn new_op_from_onnx(
                stride: (stride_h, stride_w),
            })
        }
        "Downsample" => {
            let downsample_node: Downsample = match node.op().downcast_ref::<Downsample>() {
                Some(b) => b.clone(),
                None => {
                    return Err(Box::new(GraphError::OpMismatch(
                        idx,
                        "downsample".to_string(),
                    )));
                }
            };

            Box::new(PolyOp::Downsample {
                axis: downsample_node.axis,
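                // tract's stride field is a signed integer (it can encode
                // reversed strides); the cast assumes a non-negative stride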
                stride: downsample_node.stride as usize,
                modulo: downsample_node.modulo,
            })
        }

        "Resize" => {
            // this is a bit hacky, but we need to extract the resize node somehow
1 change: 1 addition & 0 deletions src/lib.rs
@@ -23,6 +23,7 @@
    unsafe_code
)]
#![feature(lint_reasons)]
#![feature(int_roundings)]

//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
65 changes: 65 additions & 0 deletions src/tensor/ops.rs
@@ -708,6 +708,71 @@ pub fn sum<T: TensorType + Add<Output = T>>(a: &Tensor<T>) -> Result<Tensor<T>,
    Tensor::new(Some(&[res]), &[1])
}

/// Downsamples a tensor along a dimension.
/// # Arguments
/// * `input` - Tensor
/// * `dim` - Dimension to downsample along
/// * `stride` - Stride to downsample by
/// * `modulo` - Offset along `dim` at which sampling starts
/// # Examples
/// ```
/// use ezkl_lib::tensor::Tensor;
/// use ezkl_lib::tensor::ops::downsample;
/// let x = Tensor::<i128>::new(
///     Some(&[1, 2, 3, 4, 5, 6]),
///     &[2, 3],
/// ).unwrap();
/// let result = downsample(&x, 0, 1, 1).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[4, 5, 6]), &[1, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// let result = downsample(&x, 1, 2, 0).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 3, 4, 6]), &[2, 2]).unwrap();
/// assert_eq!(result, expected);
///
/// let result = downsample(&x, 1, 2, 1).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[2, 5]), &[2, 1]).unwrap();
/// assert_eq!(result, expected);
///
/// let result = downsample(&x, 1, 2, 2).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[3, 6]), &[2, 1]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn downsample<T: TensorType>(
    input: &Tensor<T>,
    dim: usize,
    stride: usize,
    modulo: usize,
) -> Result<Tensor<T>, TensorError> {
    assert!(modulo <= input.dims()[dim]);
    let mut output_shape = input.dims().to_vec();
    output_shape[dim] = (input.dims()[dim] - modulo).div_ceil(stride);
    let mut output = Tensor::<T>::new(None, &output_shape)?;

    // now downsample along axis `dim`, offset by `modulo`
    let indices = (0..output_shape.len())
        .map(|i| {
            if i == dim {
                let mut index = vec![0; output_shape[i]];
                for (i, idx) in index.iter_mut().enumerate() {
                    *idx = i * stride + modulo;
                }
                index
            } else {
                (0..output_shape[i]).collect_vec()
            }
        })
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    output
        .iter_mut()
        .zip(indices.iter())
        .for_each(|(o, i)| *o = input.get(i.as_slice()));

    Ok(output)
}

/// Gathers a tensor along a dimension.
/// # Arguments
/// * `input` - Tensor
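For reference, the output length along the downsampled axis is `ceil((len - modulo) / stride)`, which is also why the commit enables the `int_roundings` feature in `src/lib.rs`: integer `div_ceil` was nightly-only at the time. A small sketch checking the formula against the doc-comment examples above (hypothetical helper, not part of the commit):

```rust
// ceil((len - modulo) / stride), written without the nightly div_ceil
fn out_len(len: usize, stride: usize, modulo: usize) -> usize {
    (len - modulo + stride - 1) / stride
}

fn main() {
    // the [2, 3] tensor from the downsample doc examples:
    assert_eq!(out_len(2, 1, 1), 1); // dim 0, stride 1, modulo 1 -> shape [1, 3]
    assert_eq!(out_len(3, 2, 0), 2); // dim 1, stride 2, modulo 0 -> shape [2, 2]
    assert_eq!(out_len(3, 2, 1), 1); // dim 1, stride 2, modulo 1 -> shape [2, 1]
    assert_eq!(out_len(3, 2, 2), 1); // dim 1, stride 2, modulo 2 -> shape [2, 1]
}
```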
5 changes: 3 additions & 2 deletions tests/integration_tests.rs
@@ -154,7 +154,7 @@ mod native_tests {

    const LARGE_TESTS: [&str; 3] = ["self_attention", "nanoGPT", "mobilenet"];

-    const TESTS: [&str; 34] = [
+    const TESTS: [&str; 35] = [
        "1l_mlp",
        "1l_slice",
        "1l_concat",
@@ -174,6 +174,7 @@
        "1l_gelu_noappx",
        // "1l_gelu_tanh_appx",
        "1l_relu",
+        "1l_downsample",
        "1l_tanh",
        "2l_relu_sigmoid_small",
        "2l_relu_fc",
@@ -293,7 +294,7 @@

    }


-    seq!(N in 0..=33 {
+    seq!(N in 0..=34 {

        #(#[test_case(TESTS[N])])*
        fn render_circuit_(test: &str) {
