Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: scatter elements argument #503

Merged
merged 4 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions examples/onnx/scatter_elements/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from torch import nn
import torch
import json
import numpy as np


class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()

def forward(self, w, x, src):
# scatter_elements
return w.scatter(2, x, src)


circuit = MyModel()


w = torch.rand(1, 15, 18)
src = torch.rand(1, 15, 2)
x = torch.randint(0, 15, (1, 15, 2))

torch.onnx.export(circuit, (w, x, src), "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
# the model's input names
input_names=['input', 'input1', 'input2'],
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'input1': {0: 'batch_size'},
'input2': {0: 'batch_size'},
'output': {0: 'batch_size'}})


d = ((w).detach().numpy()).reshape([-1]).tolist()
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
d2 = ((src).detach().numpy()).reshape([-1]).tolist()


data = dict(
input_data=[d, d1, d2],
)

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
1 change: 1 addition & 0 deletions examples/onnx/scatter_elements/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"input_data": [[0.4006722569465637, 0.2608814835548401, 0.018605709075927734, 0.2076982855796814, 0.09738868474960327, 0.7288642525672913, 0.5645512342453003, 0.1282898187637329, 0.4371366500854492, 0.06852233409881592, 0.2802664637565613, 0.13553869724273682, 0.34616756439208984, 0.07160574197769165, 0.003862321376800537, 0.48938989639282227, 0.44164425134658813, 0.5328972935676575, 0.07106280326843262, 0.09725075960159302, 0.2589840292930603, 0.23286551237106323, 0.3244437575340271, 0.6505903005599976, 0.938034176826477, 0.26783543825149536, 0.786094069480896, 0.18789011240005493, 0.606765627861023, 0.19948190450668335, 0.3781137466430664, 0.8510081768035889, 0.13566946983337402, 0.9276034832000732, 0.5269846320152283, 0.3013773560523987, 0.5438669919967651, 0.12605935335159302, 0.4066329002380371, 0.24718689918518066, 0.4687347412109375, 0.23328912258148193, 0.8731061220169067, 0.3781728744506836, 0.5772581696510315, 0.8565052151679993, 0.7540851831436157, 0.779710590839386, 0.2353827953338623, 0.5320121645927429, 0.5085064172744751, 0.6413846015930176, 0.0774533748626709, 0.43543362617492676, 0.3591287136077881, 0.0804855227470398, 0.24766361713409424, 0.29354000091552734, 0.424951434135437, 0.43333709239959717, 0.40345197916030884, 0.8565409183502197, 0.06852656602859497, 0.04801344871520996, 0.1688656210899353, 0.9150650501251221, 0.3046884536743164, 0.380337119102478, 0.699865996837616, 0.7464215755462646, 0.390849232673645, 0.8992515802383423, 0.612838089466095, 0.4444986581802368, 0.9558160901069641, 0.7544505000114441, 0.394461989402771, 0.8434802293777466, 0.5277895331382751, 0.20150727033615112, 0.8635226488113403, 0.588026762008667, 0.260425865650177, 0.11915439367294312, 0.630763053894043, 0.35954052209854126, 0.6132745742797852, 0.5505245923995972, 0.557021975517273, 0.6679773926734924, 0.5876642465591431, 0.6560839414596558, 0.9582574963569641, 0.3779008984565735, 0.9340549111366272, 0.18399536609649658, 0.25960028171539307, 0.026599764823913574, 0.32735109329223633, 0.904168963432312, 0.14129406213760376, 0.82865971326828, 0.1153140664100647, 0.32517731189727783, 0.5030156970024109, 0.751384437084198, 0.6885764598846436, 0.34171855449676514, 0.6160555481910706, 0.16125458478927612, 0.9489563703536987, 0.1679406762123108, 0.20235764980316162, 0.881341278553009, 0.5121568441390991, 0.92611163854599, 0.4272810220718384, 0.2771685719490051, 0.9968149065971375, 0.15203642845153809, 0.4641759395599365, 0.413804292678833, 0.651200532913208, 0.32379841804504395, 0.6092511415481567, 0.8478893637657166, 0.9164987802505493, 0.7830571532249451, 0.8511074185371399, 0.15598517656326294, 0.37074369192123413, 0.4150542616844177, 0.5372101068496704, 0.14770209789276123, 0.23627835512161255, 0.06493926048278809, 0.09793734550476074, 0.8952625393867493, 0.3102150559425354, 0.43397819995880127, 0.5844067931175232, 0.3500853180885315, 0.22773021459579468, 0.05151098966598511, 0.6001721620559692, 0.3343912363052368, 0.46624046564102173, 0.549772322177887, 0.7372040748596191, 0.042217373847961426, 0.49783337116241455, 0.6191272735595703, 0.10424506664276123, 0.25230395793914795, 0.6359071135520935, 0.2743602991104126, 0.9432371258735657, 0.9513012170791626, 0.03317856788635254, 0.6779598593711853, 0.16297829151153564, 0.7508324384689331, 0.45239394903182983, 0.4571657180786133, 0.2818666100502014, 0.17718660831451416, 0.3443911671638489, 0.9507452845573425, 0.9657842516899109, 0.5738788843154907, 0.1653851866722107, 0.9145194292068481, 0.3741937279701233, 0.05001652240753174, 0.984059751033783, 0.6154639720916748, 0.6324515342712402, 0.08717000484466553, 0.894913911819458, 0.15595513582229614, 0.7157255411148071, 0.7270150780677795, 0.8960562944412231, 0.8680596947669983, 0.4046359062194824, 0.13201695680618286, 0.7696574926376343, 0.6563123464584351, 0.4042320251464844, 0.7971198558807373, 0.3859425187110901, 0.1174202561378479, 0.689612865447998, 0.3647807240486145, 0.4592200517654419, 0.05106186866760254, 0.6523975133895874, 0.2841695547103882, 0.07473695278167725, 0.5434084534645081, 0.20653265714645386, 0.31296950578689575, 0.3648838400840759, 0.5561838150024414, 0.34996211528778076, 0.4311484098434448, 0.7667861580848694, 0.8519712686538696, 0.9926041960716248, 0.11216002702713013, 0.632348895072937, 0.12260669469833374, 0.4656309485435486, 0.9705881476402283, 0.21109485626220703, 0.08676731586456299, 0.16977214813232422, 0.16263240575790405, 0.685645580291748, 0.14296847581863403, 0.6883099675178528, 0.046499550342559814, 0.026215791702270508, 0.2830098867416382, 0.5826435685157776, 0.5168894529342651, 0.21555876731872559, 0.6396673917770386, 0.5840371251106262, 0.8724454641342163, 0.13193941116333008, 0.6795164942741394, 0.15844780206680298, 0.27272599935531616, 0.6294320225715637, 0.35720330476760864, 0.8047906160354614, 0.392769455909729, 0.12731897830963135, 0.0994114875793457, 0.06120210886001587, 0.1784428358078003, 0.35258960723876953, 0.24094289541244507, 0.49197083711624146, 0.5611958503723145, 0.9082142114639282, 0.010551989078521729, 0.22507983446121216, 0.007577002048492432, 0.27637380361557007, 0.8734598159790039, 0.623768150806427, 0.10755449533462524, 0.4494497776031494, 0.06208443641662598, 0.5997845530509949, 0.8712562322616577, 0.4752771854400635, 0.3971136808395386, 0.01836305856704712, 0.4282991290092468, 0.40080422163009644, 0.757233738899231, 0.17318272590637207, 0.42513030767440796, 0.06316602230072021, 0.5725610256195068, 0.14671063423156738, 0.8214869499206543], [2, 0, 13, 12, 1, 0, 6, 3, 0, 4, 10, 0, 9, 3, 14, 13, 12, 0, 0, 11, 8, 14, 0, 8, 13, 6, 14, 9, 2, 10], [0.3678196668624878, 0.5077707171440125, 0.7048060894012451, 0.824059784412384, 0.5817335247993469, 0.1481603980064392, 0.7114542722702026, 0.46050769090652466, 0.2019960880279541, 0.47866880893707275, 0.6982521414756775, 0.8433347344398499, 0.3427056074142456, 0.182439923286438, 0.37574678659439087, 0.93959641456604, 0.5769227147102356, 0.5917701125144958, 0.30993229150772095, 0.26509368419647217, 0.3919970393180847, 0.04614824056625366, 0.09493398666381836, 0.8279429078102112, 0.29439830780029297, 0.42815113067626953, 0.4152073264122009, 0.33196568489074707, 0.24249297380447388, 0.9800586700439453]]}
30 changes: 30 additions & 0 deletions examples/onnx/scatter_elements/network.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
pytorch2.0.1:�
O
input
input1
input2output/ScatterElements"ScatterElements*
axis� torch_jitZ%
input


batch_size

Z&
input1


batch_size

Z&
input2


batch_size

b&
output


batch_size

B
34 changes: 34 additions & 0 deletions src/circuit/ops/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,18 @@ pub enum HybridOp {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
}

impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
HybridOp::ScatterElements { .. } => vec![0, 2],
_ => vec![],
}
}
Expand Down Expand Up @@ -176,6 +181,20 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
(res.clone(), inter_equals)
}
}
HybridOp::ScatterElements { dim, constant_idx } => {
let src = inputs[2].clone().map(|x| felt_to_i128(x));
if let Some(idx) = constant_idx {
log::debug!("idx: {}", idx.show());
let res = tensor::ops::scatter(&x, idx, &src, *dim)?;
(res.clone(), vec![])
} else {
let idx = inputs[1].clone().map(|x| felt_to_i128(x));
let inter_equals: Vec<Tensor<i128>> =
vec![Tensor::from(0..x.dims()[*dim] as i128)];
let res = tensor::ops::scatter(&x, &idx.map(|x| x as usize), &src, *dim)?;
(res.clone(), inter_equals)
}
}
HybridOp::MaxPool2d {
padding,
stride,
Expand Down Expand Up @@ -253,6 +272,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
HybridOp::Gather { dim, .. } => format!("GATHER (dim={})", dim),
HybridOp::TopK { k, dim } => format!("TOPK (k={}, dim={})", k, dim),
HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
HybridOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
HybridOp::OneHot { dim, num_classes } => {
format!("ONEHOT (dim={}, num_classes={})", dim, num_classes)
}
Expand Down Expand Up @@ -281,6 +301,19 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
}
}
HybridOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
*dim,
)?
.into()
} else {
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
HybridOp::MaxPool2d {
padding,
stride,
Expand Down Expand Up @@ -395,6 +428,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
HybridOp::Gather { .. }
| HybridOp::OneHot { .. }
| HybridOp::GatherElements { .. }
| HybridOp::ScatterElements { .. }
| HybridOp::Equals { .. } => {
vec![LookupOp::KroneckerDelta]
}
Expand Down
135 changes: 125 additions & 10 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
input.flatten();

// assert we have a single index
// assert_eq!(index.dims().iter().product::<usize>(), 1);
assert_eq!(index.dims().iter().product::<usize>(), 1);
assert!(dim_indices.all_prev_assigned() || region.is_dummy());

let is_assigned = !input.any_unknowns() && !index.any_unknowns();
Expand Down Expand Up @@ -715,13 +715,19 @@ pub fn gather<F: PrimeField + TensorType + PartialOrd>(
let (mut input, mut index) = (values[0].clone(), values[1].clone());
index.flatten();

let mut assigned_len = vec![];
if !input.all_prev_assigned() {
input = region.assign(&config.inputs[0], &input)?;
assigned_len.push(input.len());
}
if !index.all_prev_assigned() {
index = region.assign(&config.inputs[1], &index)?;
assigned_len.push(index.len());
}

if !assigned_len.is_empty() {
region.increment(assigned_len.iter().max().unwrap().clone());
}
region.increment(std::cmp::max(input.len(), index.len()));

// Calculate the output tensor size
let input_dims = input.dims();
Expand Down Expand Up @@ -858,6 +864,115 @@ pub fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
Ok(output.into())
}

/// Gather accumulated layout
pub fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let (mut input, mut index, mut src) = (values[0].clone(), values[1].clone(), values[2].clone());

assert_eq!(input.dims().len(), index.dims().len());

let mut assigned_len = vec![];

if !input.all_prev_assigned() {
input = region.assign(&config.inputs[0], &input)?;
assigned_len.push(input.len());
}
if !index.all_prev_assigned() {
index = region.assign(&config.inputs[1], &index)?;
assigned_len.push(index.len());
}
if !src.all_prev_assigned() {
src = region.assign(&config.output, &src)?;
assigned_len.push(src.len());
}

if !assigned_len.is_empty() {
region.increment(*assigned_len.iter().max().unwrap());
}

// Calculate the output tensor size
let input_dim = input.dims()[dim];
let output_size = index.dims().to_vec();

// these will be assigned as constants
let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x)));
indices.set_visibility(&crate::graph::Visibility::Fixed);
let indices = region.assign(&config.inputs[1], &indices.into())?;
region.increment(indices.len());

// Allocate memory for the output tensor
let cartesian_coord = output_size
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();

let mut unit = Tensor::from(vec![F::from(1)].into_iter());
unit.set_visibility(&crate::graph::Visibility::Fixed);
let unit: ValTensor<F> = unit.into();
region.assign(&config.inputs[1], &unit).unwrap();
region.increment(1);

let mut output = Tensor::new(None, &output_size)?;

let mut inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| -> () {
let coord = cartesian_coord[i].clone();
let index_val = index.get_inner_tensor().unwrap().get(&coord);

let src_val = src.get_inner_tensor().unwrap().get(&coord);
let src_valtensor: ValTensor<F> = Tensor::from([src_val.clone()].into_iter()).into();

let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
slice[dim] = 0..input_dim;

let mut sliced_input = input.get_slice(&slice).unwrap();
sliced_input.flatten();

let index_valtensor: ValTensor<F> = Tensor::from([index_val.clone()].into_iter()).into();

let mask = equals(config, region, &[index_valtensor, indices.clone()]).unwrap();

let one_minus_mask =
pairwise(config, region, &[unit.clone(), mask.clone()], BaseOp::Sub).unwrap();

let pairwise_prod = pairwise(config, region, &[src_valtensor, mask], BaseOp::Mult).unwrap();
let pairwise_prod_2 = pairwise(
config,
region,
&[sliced_input, one_minus_mask],
BaseOp::Mult,
)
.unwrap();

let res = pairwise(
config,
region,
&[pairwise_prod, pairwise_prod_2],
BaseOp::Add,
)
.unwrap();

let input_cartesian_coord = slice.into_iter().multi_cartesian_product();

let mutable_input_inner = input.get_inner_tensor_mut().unwrap();

for (i, r) in res.get_inner_tensor().unwrap().iter().enumerate() {
let coord = input_cartesian_coord.clone().nth(i).unwrap();
*mutable_input_inner.get_mut(&coord) = r.clone();
}
};

output.iter_mut().enumerate().for_each(|(i, o)| {
*o = inner_loop_function(i, region);
});

Ok(input)
}

/// sum accumulated layout
pub fn sum<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
Expand Down Expand Up @@ -1750,20 +1865,21 @@ pub fn conv<F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::m

// we specifically want to use the same kernel and image for all the convolutions and need to enforce this by assigning them
// 1. assign the kernel
let mut assigned_kernel_len = 0;
let mut assigned_len = vec![];

if !kernel.all_prev_assigned() {
kernel = region.assign(&config.inputs[0], &kernel)?;
assigned_kernel_len = kernel.len();
assigned_len.push(kernel.len());
}
// 2. assign the image
let mut assigned_image_len = 0;
if !image.all_prev_assigned() {
image = region.assign(&config.inputs[1], &image)?;
assigned_image_len = image.len();
assigned_len.push(image.len());
}

// increment the region
region.increment(std::cmp::max(assigned_kernel_len, assigned_image_len));
if !assigned_len.is_empty() {
region.increment(*assigned_len.iter().max().unwrap());
}

let og_dims = image.dims().to_vec();

Expand Down Expand Up @@ -2446,14 +2562,13 @@ fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd>(

if !input.all_prev_assigned() {
input = region.assign(&config.inputs[0], &input)?;
region.increment(input.len());
}

if input.dims().len() == 1 {
return op(config, region, &[input]);
}

region.increment(input.len());

// Calculate the output tensor size
let input_dims = input.dims();

Expand Down
Loading