feat: scatter elements argument (#503)
alexander-camuto authored Sep 27, 2023
1 parent 4fbde29 commit 723b054
Showing 10 changed files with 363 additions and 16 deletions.
46 changes: 46 additions & 0 deletions examples/onnx/scatter_elements/gen.py
@@ -0,0 +1,46 @@
from torch import nn
import torch
import json
import numpy as np


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, w, x, src):
        # scatter_elements
        return w.scatter(2, x, src)


circuit = MyModel()


w = torch.rand(1, 15, 18)
src = torch.rand(1, 15, 2)
x = torch.randint(0, 15, (1, 15, 2))

torch.onnx.export(circuit, (w, x, src), "network.onnx",
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=15,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  # the model's input names
                  input_names=['input', 'input1', 'input2'],
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},  # variable length axes
                                'input1': {0: 'batch_size'},
                                'input2': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})


d = ((w).detach().numpy()).reshape([-1]).tolist()
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
d2 = ((src).detach().numpy()).reshape([-1]).tolist()


data = dict(
    input_data=[d, d1, d2],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
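As a quick illustration of the semantics being exported here, Tensor.scatter(dim, index, src) writes each entry of src into the destination at the position named by index along dim and leaves every other position untouched. A minimal sketch on tiny tensors (the toy values below are illustrative and not part of the example):

import torch

w = torch.zeros(1, 2, 4)
src = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])  # shape (1, 2, 2)
idx = torch.tensor([[[0, 3], [1, 2]]])          # shape (1, 2, 2), indices into dim 2

out = w.scatter(2, idx, src)
# out[0, i, idx[0, i, j]] == src[0, i, j]; untouched positions keep w's values
print(out)
# tensor([[[1., 0., 0., 2.],
#          [0., 3., 4., 0.]]])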
1 change: 1 addition & 0 deletions examples/onnx/scatter_elements/input.json
@@ -0,0 +1 @@
{"input_data": [[0.4006722569465637, 0.2608814835548401, 0.018605709075927734, 0.2076982855796814, 0.09738868474960327, 0.7288642525672913, 0.5645512342453003, 0.1282898187637329, 0.4371366500854492, 0.06852233409881592, 0.2802664637565613, 0.13553869724273682, 0.34616756439208984, 0.07160574197769165, 0.003862321376800537, 0.48938989639282227, 0.44164425134658813, 0.5328972935676575, 0.07106280326843262, 0.09725075960159302, 0.2589840292930603, 0.23286551237106323, 0.3244437575340271, 0.6505903005599976, 0.938034176826477, 0.26783543825149536, 0.786094069480896, 0.18789011240005493, 0.606765627861023, 0.19948190450668335, 0.3781137466430664, 0.8510081768035889, 0.13566946983337402, 0.9276034832000732, 0.5269846320152283, 0.3013773560523987, 0.5438669919967651, 0.12605935335159302, 0.4066329002380371, 0.24718689918518066, 0.4687347412109375, 0.23328912258148193, 0.8731061220169067, 0.3781728744506836, 0.5772581696510315, 0.8565052151679993, 0.7540851831436157, 0.779710590839386, 0.2353827953338623, 0.5320121645927429, 0.5085064172744751, 0.6413846015930176, 0.0774533748626709, 0.43543362617492676, 0.3591287136077881, 0.0804855227470398, 0.24766361713409424, 0.29354000091552734, 0.424951434135437, 0.43333709239959717, 0.40345197916030884, 0.8565409183502197, 0.06852656602859497, 0.04801344871520996, 0.1688656210899353, 0.9150650501251221, 0.3046884536743164, 0.380337119102478, 0.699865996837616, 0.7464215755462646, 0.390849232673645, 0.8992515802383423, 0.612838089466095, 0.4444986581802368, 0.9558160901069641, 0.7544505000114441, 0.394461989402771, 0.8434802293777466, 0.5277895331382751, 0.20150727033615112, 0.8635226488113403, 0.588026762008667, 0.260425865650177, 0.11915439367294312, 0.630763053894043, 0.35954052209854126, 0.6132745742797852, 0.5505245923995972, 0.557021975517273, 0.6679773926734924, 0.5876642465591431, 0.6560839414596558, 0.9582574963569641, 0.3779008984565735, 0.9340549111366272, 0.18399536609649658, 0.25960028171539307, 0.026599764823913574, 0.32735109329223633, 0.904168963432312, 0.14129406213760376, 0.82865971326828, 0.1153140664100647, 0.32517731189727783, 0.5030156970024109, 0.751384437084198, 0.6885764598846436, 0.34171855449676514, 0.6160555481910706, 0.16125458478927612, 0.9489563703536987, 0.1679406762123108, 0.20235764980316162, 0.881341278553009, 0.5121568441390991, 0.92611163854599, 0.4272810220718384, 0.2771685719490051, 0.9968149065971375, 0.15203642845153809, 0.4641759395599365, 0.413804292678833, 0.651200532913208, 0.32379841804504395, 0.6092511415481567, 0.8478893637657166, 0.9164987802505493, 0.7830571532249451, 0.8511074185371399, 0.15598517656326294, 0.37074369192123413, 0.4150542616844177, 0.5372101068496704, 0.14770209789276123, 0.23627835512161255, 0.06493926048278809, 0.09793734550476074, 0.8952625393867493, 0.3102150559425354, 0.43397819995880127, 0.5844067931175232, 0.3500853180885315, 0.22773021459579468, 0.05151098966598511, 0.6001721620559692, 0.3343912363052368, 0.46624046564102173, 0.549772322177887, 0.7372040748596191, 0.042217373847961426, 0.49783337116241455, 0.6191272735595703, 0.10424506664276123, 0.25230395793914795, 0.6359071135520935, 0.2743602991104126, 0.9432371258735657, 0.9513012170791626, 0.03317856788635254, 0.6779598593711853, 0.16297829151153564, 0.7508324384689331, 0.45239394903182983, 0.4571657180786133, 0.2818666100502014, 0.17718660831451416, 0.3443911671638489, 0.9507452845573425, 0.9657842516899109, 0.5738788843154907, 0.1653851866722107, 0.9145194292068481, 0.3741937279701233, 0.05001652240753174, 0.984059751033783, 
0.6154639720916748, 0.6324515342712402, 0.08717000484466553, 0.894913911819458, 0.15595513582229614, 0.7157255411148071, 0.7270150780677795, 0.8960562944412231, 0.8680596947669983, 0.4046359062194824, 0.13201695680618286, 0.7696574926376343, 0.6563123464584351, 0.4042320251464844, 0.7971198558807373, 0.3859425187110901, 0.1174202561378479, 0.689612865447998, 0.3647807240486145, 0.4592200517654419, 0.05106186866760254, 0.6523975133895874, 0.2841695547103882, 0.07473695278167725, 0.5434084534645081, 0.20653265714645386, 0.31296950578689575, 0.3648838400840759, 0.5561838150024414, 0.34996211528778076, 0.4311484098434448, 0.7667861580848694, 0.8519712686538696, 0.9926041960716248, 0.11216002702713013, 0.632348895072937, 0.12260669469833374, 0.4656309485435486, 0.9705881476402283, 0.21109485626220703, 0.08676731586456299, 0.16977214813232422, 0.16263240575790405, 0.685645580291748, 0.14296847581863403, 0.6883099675178528, 0.046499550342559814, 0.026215791702270508, 0.2830098867416382, 0.5826435685157776, 0.5168894529342651, 0.21555876731872559, 0.6396673917770386, 0.5840371251106262, 0.8724454641342163, 0.13193941116333008, 0.6795164942741394, 0.15844780206680298, 0.27272599935531616, 0.6294320225715637, 0.35720330476760864, 0.8047906160354614, 0.392769455909729, 0.12731897830963135, 0.0994114875793457, 0.06120210886001587, 0.1784428358078003, 0.35258960723876953, 0.24094289541244507, 0.49197083711624146, 0.5611958503723145, 0.9082142114639282, 0.010551989078521729, 0.22507983446121216, 0.007577002048492432, 0.27637380361557007, 0.8734598159790039, 0.623768150806427, 0.10755449533462524, 0.4494497776031494, 0.06208443641662598, 0.5997845530509949, 0.8712562322616577, 0.4752771854400635, 0.3971136808395386, 0.01836305856704712, 0.4282991290092468, 0.40080422163009644, 0.757233738899231, 0.17318272590637207, 0.42513030767440796, 0.06316602230072021, 0.5725610256195068, 0.14671063423156738, 0.8214869499206543], [2, 0, 13, 12, 1, 0, 6, 3, 0, 4, 10, 0, 9, 3, 14, 13, 12, 0, 0, 11, 8, 14, 0, 8, 13, 6, 14, 9, 2, 10], [0.3678196668624878, 0.5077707171440125, 0.7048060894012451, 0.824059784412384, 0.5817335247993469, 0.1481603980064392, 0.7114542722702026, 0.46050769090652466, 0.2019960880279541, 0.47866880893707275, 0.6982521414756775, 0.8433347344398499, 0.3427056074142456, 0.182439923286438, 0.37574678659439087, 0.93959641456604, 0.5769227147102356, 0.5917701125144958, 0.30993229150772095, 0.26509368419647217, 0.3919970393180847, 0.04614824056625366, 0.09493398666381836, 0.8279429078102112, 0.29439830780029297, 0.42815113067626953, 0.4152073264122009, 0.33196568489074707, 0.24249297380447388, 0.9800586700439453]]}
30 changes: 30 additions & 0 deletions examples/onnx/scatter_elements/network.onnx
@@ -0,0 +1,30 @@
(binary ONNX protobuf, not human-readable: a single ScatterElements node exported by pytorch 2.0.1, wired from inputs input, input1, input2 to output, with an axis attribute and a dynamic batch_size leading dimension on every tensor)
34 changes: 34 additions & 0 deletions src/circuit/ops/hybrid.rs
@@ -57,13 +57,18 @@ pub enum HybridOp {
        dim: usize,
        constant_idx: Option<Tensor<usize>>,
    },
    ScatterElements {
        dim: usize,
        constant_idx: Option<Tensor<usize>>,
    },
}

impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
    ///
    fn requires_homogenous_input_scales(&self) -> Vec<usize> {
        match self {
            HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
            HybridOp::ScatterElements { .. } => vec![0, 2],
            _ => vec![],
        }
    }
@@ -176,6 +181,20 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
                    (res.clone(), inter_equals)
                }
            }
            HybridOp::ScatterElements { dim, constant_idx } => {
                let src = inputs[2].clone().map(|x| felt_to_i128(x));
                if let Some(idx) = constant_idx {
                    log::debug!("idx: {}", idx.show());
                    let res = tensor::ops::scatter(&x, idx, &src, *dim)?;
                    (res.clone(), vec![])
                } else {
                    let idx = inputs[1].clone().map(|x| felt_to_i128(x));
                    let inter_equals: Vec<Tensor<i128>> =
                        vec![Tensor::from(0..x.dims()[*dim] as i128)];
                    let res = tensor::ops::scatter(&x, &idx.map(|x| x as usize), &src, *dim)?;
                    (res.clone(), inter_equals)
                }
            }
            HybridOp::MaxPool2d {
                padding,
                stride,
@@ -253,6 +272,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
            HybridOp::Gather { dim, .. } => format!("GATHER (dim={})", dim),
            HybridOp::TopK { k, dim } => format!("TOPK (k={}, dim={})", k, dim),
            HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
            HybridOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
            HybridOp::OneHot { dim, num_classes } => {
                format!("ONEHOT (dim={}, num_classes={})", dim, num_classes)
            }
@@ -281,6 +301,19 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
                    layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
                }
            }
            HybridOp::ScatterElements { dim, constant_idx } => {
                if let Some(idx) = constant_idx {
                    tensor::ops::scatter(
                        values[0].get_inner_tensor()?,
                        idx,
                        values[1].get_inner_tensor()?,
                        *dim,
                    )?
                    .into()
                } else {
                    layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
                }
            }
            HybridOp::MaxPool2d {
                padding,
                stride,
@@ -395,6 +428,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
            HybridOp::Gather { .. }
            | HybridOp::OneHot { .. }
            | HybridOp::GatherElements { .. }
            | HybridOp::ScatterElements { .. }
            | HybridOp::Equals { .. } => {
                vec![LookupOp::KroneckerDelta]
            }
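For orientation, the non-circuit paths above (the constant_idx branch and the Op::f witness computation through tensor::ops::scatter) perform the ordinary scatter along a dimension. A rough NumPy reference of that semantics, using a hypothetical helper name that is not part of the crate's API:

import numpy as np

def scatter_ref(inp, index, src, dim):
    """Reference scatter: out[..., index[c], ...] = src[c] along `dim` (illustrative only)."""
    out = inp.copy()
    for c in np.ndindex(*index.shape):  # iterate every coordinate of `index`/`src`
        target = list(c)
        target[dim] = index[c]          # redirect the coordinate along `dim`
        out[tuple(target)] = src[c]
    return out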
135 changes: 125 additions & 10 deletions src/circuit/ops/layouts.rs
@@ -553,7 +553,7 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
    input.flatten();

    // assert we have a single index
    // assert_eq!(index.dims().iter().product::<usize>(), 1);
    assert_eq!(index.dims().iter().product::<usize>(), 1);
    assert!(dim_indices.all_prev_assigned() || region.is_dummy());

    let is_assigned = !input.any_unknowns() && !index.any_unknowns();
@@ -715,13 +715,19 @@ pub fn gather<F: PrimeField + TensorType + PartialOrd>(
    let (mut input, mut index) = (values[0].clone(), values[1].clone());
    index.flatten();

    let mut assigned_len = vec![];
    if !input.all_prev_assigned() {
        input = region.assign(&config.inputs[0], &input)?;
        assigned_len.push(input.len());
    }
    if !index.all_prev_assigned() {
        index = region.assign(&config.inputs[1], &index)?;
        assigned_len.push(index.len());
    }

    if !assigned_len.is_empty() {
        region.increment(assigned_len.iter().max().unwrap().clone());
    }
    region.increment(std::cmp::max(input.len(), index.len()));

    // Calculate the output tensor size
    let input_dims = input.dims();
@@ -858,6 +864,115 @@ pub fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
    Ok(output.into())
}

/// Scatter elements accumulated layout
pub fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 3],
    dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
    let (mut input, mut index, mut src) = (values[0].clone(), values[1].clone(), values[2].clone());

    assert_eq!(input.dims().len(), index.dims().len());

    let mut assigned_len = vec![];

    if !input.all_prev_assigned() {
        input = region.assign(&config.inputs[0], &input)?;
        assigned_len.push(input.len());
    }
    if !index.all_prev_assigned() {
        index = region.assign(&config.inputs[1], &index)?;
        assigned_len.push(index.len());
    }
    if !src.all_prev_assigned() {
        src = region.assign(&config.output, &src)?;
        assigned_len.push(src.len());
    }

    if !assigned_len.is_empty() {
        region.increment(*assigned_len.iter().max().unwrap());
    }

    // Calculate the output tensor size
    let input_dim = input.dims()[dim];
    let output_size = index.dims().to_vec();

    // these will be assigned as constants
    let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x)));
    indices.set_visibility(&crate::graph::Visibility::Fixed);
    let indices = region.assign(&config.inputs[1], &indices.into())?;
    region.increment(indices.len());

    // Allocate memory for the output tensor
    let cartesian_coord = output_size
        .iter()
        .map(|x| 0..*x)
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    let mut unit = Tensor::from(vec![F::from(1)].into_iter());
    unit.set_visibility(&crate::graph::Visibility::Fixed);
    let unit: ValTensor<F> = unit.into();
    region.assign(&config.inputs[1], &unit).unwrap();
    region.increment(1);

    let mut output = Tensor::new(None, &output_size)?;

    let mut inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| -> () {
        let coord = cartesian_coord[i].clone();
        let index_val = index.get_inner_tensor().unwrap().get(&coord);

        let src_val = src.get_inner_tensor().unwrap().get(&coord);
        let src_valtensor: ValTensor<F> = Tensor::from([src_val.clone()].into_iter()).into();

        let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
        slice[dim] = 0..input_dim;

        let mut sliced_input = input.get_slice(&slice).unwrap();
        sliced_input.flatten();

        let index_valtensor: ValTensor<F> = Tensor::from([index_val.clone()].into_iter()).into();

        let mask = equals(config, region, &[index_valtensor, indices.clone()]).unwrap();

        let one_minus_mask =
            pairwise(config, region, &[unit.clone(), mask.clone()], BaseOp::Sub).unwrap();

        let pairwise_prod = pairwise(config, region, &[src_valtensor, mask], BaseOp::Mult).unwrap();
        let pairwise_prod_2 = pairwise(
            config,
            region,
            &[sliced_input, one_minus_mask],
            BaseOp::Mult,
        )
        .unwrap();

        let res = pairwise(
            config,
            region,
            &[pairwise_prod, pairwise_prod_2],
            BaseOp::Add,
        )
        .unwrap();

        let input_cartesian_coord = slice.into_iter().multi_cartesian_product();

        let mutable_input_inner = input.get_inner_tensor_mut().unwrap();

        for (i, r) in res.get_inner_tensor().unwrap().iter().enumerate() {
            let coord = input_cartesian_coord.clone().nth(i).unwrap();
            *mutable_input_inner.get_mut(&coord) = r.clone();
        }
    };

    output.iter_mut().enumerate().for_each(|(i, o)| {
        *o = inner_loop_function(i, region);
    });

    Ok(input)
}
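Inside the circuit, scatter_elements never indexes the witness directly: for every coordinate of index/src it compares the index value against the fixed indices constants with equals (a Kronecker-delta mask) and then blends the slice along dim as src * mask + input * (1 - mask) via pairwise Mult/Sub/Add. A small NumPy sketch of that blend identity on a single slice (illustrative only, not the circuit API):

import numpy as np

def scatter_row_via_mask(row, index_val, src_val):
    """Blend one slice along dim: out[j] = src*delta(index==j) + row[j]*(1 - delta(index==j))."""
    positions = np.arange(len(row))                     # plays the role of the fixed `indices` constants
    mask = (positions == index_val).astype(row.dtype)   # Kronecker delta
    return src_val * mask + row * (1.0 - mask)

row = np.array([0.1, 0.2, 0.3, 0.4])
print(scatter_row_via_mask(row, 2, 9.0))  # -> [0.1 0.2 9.  0.4]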

/// sum accumulated layout
pub fn sum<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -1750,20 +1865,21 @@ pub fn conv<F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::m

    // we specifically want to use the same kernel and image for all the convolutions and need to enforce this by assigning them
    // 1. assign the kernel
    let mut assigned_kernel_len = 0;
    let mut assigned_len = vec![];

    if !kernel.all_prev_assigned() {
        kernel = region.assign(&config.inputs[0], &kernel)?;
        assigned_kernel_len = kernel.len();
        assigned_len.push(kernel.len());
    }
    // 2. assign the image
    let mut assigned_image_len = 0;
    if !image.all_prev_assigned() {
        image = region.assign(&config.inputs[1], &image)?;
        assigned_image_len = image.len();
        assigned_len.push(image.len());
    }

    // increment the region
    region.increment(std::cmp::max(assigned_kernel_len, assigned_image_len));
    if !assigned_len.is_empty() {
        region.increment(*assigned_len.iter().max().unwrap());
    }

    let og_dims = image.dims().to_vec();

@@ -2446,14 +2562,13 @@ fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd>(

    if !input.all_prev_assigned() {
        input = region.assign(&config.inputs[0], &input)?;
        region.increment(input.len());
    }

    if input.dims().len() == 1 {
        return op(config, region, &[input]);
    }

    region.increment(input.len());

    // Calculate the output tensor size
    let input_dims = input.dims();
