From 723b05441320e3d78e0227b7bf396af4d02b16f0 Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Wed, 27 Sep 2023 01:04:03 +0100 Subject: [PATCH] feat: scatter elements argument (#503) --- examples/onnx/scatter_elements/gen.py | 46 +++++++ examples/onnx/scatter_elements/input.json | 1 + examples/onnx/scatter_elements/network.onnx | 30 +++++ src/circuit/ops/hybrid.rs | 34 +++++ src/circuit/ops/layouts.rs | 135 ++++++++++++++++++-- src/graph/utilities.rs | 43 ++++++- src/tensor/mod.rs | 22 +++- src/tensor/ops.rs | 61 +++++++++ src/tensor/val.rs | 2 +- tests/integration_tests.rs | 5 +- 10 files changed, 363 insertions(+), 16 deletions(-) create mode 100644 examples/onnx/scatter_elements/gen.py create mode 100644 examples/onnx/scatter_elements/input.json create mode 100644 examples/onnx/scatter_elements/network.onnx diff --git a/examples/onnx/scatter_elements/gen.py b/examples/onnx/scatter_elements/gen.py new file mode 100644 index 000000000..aca91f392 --- /dev/null +++ b/examples/onnx/scatter_elements/gen.py @@ -0,0 +1,46 @@ +from torch import nn +import torch +import json +import numpy as np + + +class MyModel(nn.Module): + def __init__(self): + super(MyModel, self).__init__() + + def forward(self, w, x, src): + # scatter_elements + return w.scatter(2, x, src) + + +circuit = MyModel() + + +w = torch.rand(1, 15, 18) +src = torch.rand(1, 15, 2) +x = torch.randint(0, 15, (1, 15, 2)) + +torch.onnx.export(circuit, (w, x, src), "network.onnx", + export_params=True, # store the trained parameter weights inside the model file + opset_version=15, # the ONNX version to export the model to + do_constant_folding=True, # whether to execute constant folding for optimization + # the model's input names + input_names=['input', 'input1', 'input2'], + output_names=['output'], # the model's output names + dynamic_axes={'input': {0: 'batch_size'}, # variable length axes + 'input1': {0: 'batch_size'}, + 'input2': {0: 'batch_size'}, + 'output': {0: 'batch_size'}}) + + +d = ((w).detach().numpy()).reshape([-1]).tolist() +d1 = ((x).detach().numpy()).reshape([-1]).tolist() +d2 = ((src).detach().numpy()).reshape([-1]).tolist() + + +data = dict( + input_data=[d, d1, d2], +) + +# Serialize data into file: +json.dump(data, open("input.json", 'w')) diff --git a/examples/onnx/scatter_elements/input.json b/examples/onnx/scatter_elements/input.json new file mode 100644 index 000000000..ef333c1b8 --- /dev/null +++ b/examples/onnx/scatter_elements/input.json @@ -0,0 +1 @@ +{"input_data": [[0.4006722569465637, 0.2608814835548401, 0.018605709075927734, 0.2076982855796814, 0.09738868474960327, 0.7288642525672913, 0.5645512342453003, 0.1282898187637329, 0.4371366500854492, 0.06852233409881592, 0.2802664637565613, 0.13553869724273682, 0.34616756439208984, 0.07160574197769165, 0.003862321376800537, 0.48938989639282227, 0.44164425134658813, 0.5328972935676575, 0.07106280326843262, 0.09725075960159302, 0.2589840292930603, 0.23286551237106323, 0.3244437575340271, 0.6505903005599976, 0.938034176826477, 0.26783543825149536, 0.786094069480896, 0.18789011240005493, 0.606765627861023, 0.19948190450668335, 0.3781137466430664, 0.8510081768035889, 0.13566946983337402, 0.9276034832000732, 0.5269846320152283, 0.3013773560523987, 0.5438669919967651, 0.12605935335159302, 0.4066329002380371, 0.24718689918518066, 0.4687347412109375, 0.23328912258148193, 0.8731061220169067, 0.3781728744506836, 0.5772581696510315, 0.8565052151679993, 0.7540851831436157, 0.779710590839386, 0.2353827953338623, 0.5320121645927429, 0.5085064172744751, 0.6413846015930176, 0.0774533748626709, 0.43543362617492676, 0.3591287136077881, 0.0804855227470398, 0.24766361713409424, 0.29354000091552734, 0.424951434135437, 0.43333709239959717, 0.40345197916030884, 0.8565409183502197, 0.06852656602859497, 0.04801344871520996, 0.1688656210899353, 0.9150650501251221, 0.3046884536743164, 0.380337119102478, 0.699865996837616, 0.7464215755462646, 0.390849232673645, 0.8992515802383423, 0.612838089466095, 0.4444986581802368, 0.9558160901069641, 0.7544505000114441, 0.394461989402771, 0.8434802293777466, 0.5277895331382751, 0.20150727033615112, 0.8635226488113403, 0.588026762008667, 0.260425865650177, 0.11915439367294312, 0.630763053894043, 0.35954052209854126, 0.6132745742797852, 0.5505245923995972, 0.557021975517273, 0.6679773926734924, 0.5876642465591431, 0.6560839414596558, 0.9582574963569641, 0.3779008984565735, 0.9340549111366272, 0.18399536609649658, 0.25960028171539307, 0.026599764823913574, 0.32735109329223633, 0.904168963432312, 0.14129406213760376, 0.82865971326828, 0.1153140664100647, 0.32517731189727783, 0.5030156970024109, 0.751384437084198, 0.6885764598846436, 0.34171855449676514, 0.6160555481910706, 0.16125458478927612, 0.9489563703536987, 0.1679406762123108, 0.20235764980316162, 0.881341278553009, 0.5121568441390991, 0.92611163854599, 0.4272810220718384, 0.2771685719490051, 0.9968149065971375, 0.15203642845153809, 0.4641759395599365, 0.413804292678833, 0.651200532913208, 0.32379841804504395, 0.6092511415481567, 0.8478893637657166, 0.9164987802505493, 0.7830571532249451, 0.8511074185371399, 0.15598517656326294, 0.37074369192123413, 0.4150542616844177, 0.5372101068496704, 0.14770209789276123, 0.23627835512161255, 0.06493926048278809, 0.09793734550476074, 0.8952625393867493, 0.3102150559425354, 0.43397819995880127, 0.5844067931175232, 0.3500853180885315, 0.22773021459579468, 0.05151098966598511, 0.6001721620559692, 0.3343912363052368, 0.46624046564102173, 0.549772322177887, 0.7372040748596191, 0.042217373847961426, 0.49783337116241455, 0.6191272735595703, 0.10424506664276123, 0.25230395793914795, 0.6359071135520935, 0.2743602991104126, 0.9432371258735657, 0.9513012170791626, 0.03317856788635254, 0.6779598593711853, 0.16297829151153564, 0.7508324384689331, 0.45239394903182983, 0.4571657180786133, 0.2818666100502014, 0.17718660831451416, 0.3443911671638489, 0.9507452845573425, 0.9657842516899109, 0.5738788843154907, 0.1653851866722107, 0.9145194292068481, 0.3741937279701233, 0.05001652240753174, 0.984059751033783, 0.6154639720916748, 0.6324515342712402, 0.08717000484466553, 0.894913911819458, 0.15595513582229614, 0.7157255411148071, 0.7270150780677795, 0.8960562944412231, 0.8680596947669983, 0.4046359062194824, 0.13201695680618286, 0.7696574926376343, 0.6563123464584351, 0.4042320251464844, 0.7971198558807373, 0.3859425187110901, 0.1174202561378479, 0.689612865447998, 0.3647807240486145, 0.4592200517654419, 0.05106186866760254, 0.6523975133895874, 0.2841695547103882, 0.07473695278167725, 0.5434084534645081, 0.20653265714645386, 0.31296950578689575, 0.3648838400840759, 0.5561838150024414, 0.34996211528778076, 0.4311484098434448, 0.7667861580848694, 0.8519712686538696, 0.9926041960716248, 0.11216002702713013, 0.632348895072937, 0.12260669469833374, 0.4656309485435486, 0.9705881476402283, 0.21109485626220703, 0.08676731586456299, 0.16977214813232422, 0.16263240575790405, 0.685645580291748, 0.14296847581863403, 0.6883099675178528, 0.046499550342559814, 0.026215791702270508, 0.2830098867416382, 0.5826435685157776, 0.5168894529342651, 0.21555876731872559, 0.6396673917770386, 0.5840371251106262, 0.8724454641342163, 0.13193941116333008, 0.6795164942741394, 0.15844780206680298, 0.27272599935531616, 0.6294320225715637, 0.35720330476760864, 0.8047906160354614, 0.392769455909729, 0.12731897830963135, 0.0994114875793457, 0.06120210886001587, 0.1784428358078003, 0.35258960723876953, 0.24094289541244507, 0.49197083711624146, 0.5611958503723145, 0.9082142114639282, 0.010551989078521729, 0.22507983446121216, 0.007577002048492432, 0.27637380361557007, 0.8734598159790039, 0.623768150806427, 0.10755449533462524, 0.4494497776031494, 0.06208443641662598, 0.5997845530509949, 0.8712562322616577, 0.4752771854400635, 0.3971136808395386, 0.01836305856704712, 0.4282991290092468, 0.40080422163009644, 0.757233738899231, 0.17318272590637207, 0.42513030767440796, 0.06316602230072021, 0.5725610256195068, 0.14671063423156738, 0.8214869499206543], [2, 0, 13, 12, 1, 0, 6, 3, 0, 4, 10, 0, 9, 3, 14, 13, 12, 0, 0, 11, 8, 14, 0, 8, 13, 6, 14, 9, 2, 10], [0.3678196668624878, 0.5077707171440125, 0.7048060894012451, 0.824059784412384, 0.5817335247993469, 0.1481603980064392, 0.7114542722702026, 0.46050769090652466, 0.2019960880279541, 0.47866880893707275, 0.6982521414756775, 0.8433347344398499, 0.3427056074142456, 0.182439923286438, 0.37574678659439087, 0.93959641456604, 0.5769227147102356, 0.5917701125144958, 0.30993229150772095, 0.26509368419647217, 0.3919970393180847, 0.04614824056625366, 0.09493398666381836, 0.8279429078102112, 0.29439830780029297, 0.42815113067626953, 0.4152073264122009, 0.33196568489074707, 0.24249297380447388, 0.9800586700439453]]} \ No newline at end of file diff --git a/examples/onnx/scatter_elements/network.onnx b/examples/onnx/scatter_elements/network.onnx new file mode 100644 index 000000000..34eeeb329 --- /dev/null +++ b/examples/onnx/scatter_elements/network.onnx @@ -0,0 +1,30 @@ +pytorch2.0.1:û +O +input +input1 +input2output/ScatterElements"ScatterElements* +axis  torch_jitZ% +input + +  +batch_size + +Z& +input1 + +  +batch_size + +Z& +input2 + +  +batch_size + +b& +output + +  +batch_size + +B \ No newline at end of file diff --git a/src/circuit/ops/hybrid.rs b/src/circuit/ops/hybrid.rs index 57cd76380..ebd916d85 100644 --- a/src/circuit/ops/hybrid.rs +++ b/src/circuit/ops/hybrid.rs @@ -57,6 +57,10 @@ pub enum HybridOp { dim: usize, constant_idx: Option>, }, + ScatterElements { + dim: usize, + constant_idx: Option>, + }, } impl Op for HybridOp { @@ -64,6 +68,7 @@ impl Op for HybridOp { fn requires_homogenous_input_scales(&self) -> Vec { match self { HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1], + HybridOp::ScatterElements { .. } => vec![0, 2], _ => vec![], } } @@ -176,6 +181,20 @@ impl Op for HybridOp { (res.clone(), inter_equals) } } + HybridOp::ScatterElements { dim, constant_idx } => { + let src = inputs[2].clone().map(|x| felt_to_i128(x)); + if let Some(idx) = constant_idx { + log::debug!("idx: {}", idx.show()); + let res = tensor::ops::scatter(&x, idx, &src, *dim)?; + (res.clone(), vec![]) + } else { + let idx = inputs[1].clone().map(|x| felt_to_i128(x)); + let inter_equals: Vec> = + vec![Tensor::from(0..x.dims()[*dim] as i128)]; + let res = tensor::ops::scatter(&x, &idx.map(|x| x as usize), &src, *dim)?; + (res.clone(), inter_equals) + } + } HybridOp::MaxPool2d { padding, stride, @@ -253,6 +272,7 @@ impl Op for HybridOp { HybridOp::Gather { dim, .. } => format!("GATHER (dim={})", dim), HybridOp::TopK { k, dim } => format!("TOPK (k={}, dim={})", k, dim), HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim), + HybridOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim), HybridOp::OneHot { dim, num_classes } => { format!("ONEHOT (dim={}, num_classes={})", dim, num_classes) } @@ -281,6 +301,19 @@ impl Op for HybridOp { layouts::gather_elements(config, region, values[..].try_into()?, *dim)? } } + HybridOp::ScatterElements { dim, constant_idx } => { + if let Some(idx) = constant_idx { + tensor::ops::scatter( + values[0].get_inner_tensor()?, + idx, + values[1].get_inner_tensor()?, + *dim, + )? + .into() + } else { + layouts::scatter_elements(config, region, values[..].try_into()?, *dim)? + } + } HybridOp::MaxPool2d { padding, stride, @@ -395,6 +428,7 @@ impl Op for HybridOp { HybridOp::Gather { .. } | HybridOp::OneHot { .. } | HybridOp::GatherElements { .. } + | HybridOp::ScatterElements { .. } | HybridOp::Equals { .. } => { vec![LookupOp::KroneckerDelta] } diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 0af2b034f..dd07008d3 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -553,7 +553,7 @@ fn select( input.flatten(); // assert we have a single index - // assert_eq!(index.dims().iter().product::(), 1); + assert_eq!(index.dims().iter().product::(), 1); assert!(dim_indices.all_prev_assigned() || region.is_dummy()); let is_assigned = !input.any_unknowns() && !index.any_unknowns(); @@ -715,13 +715,19 @@ pub fn gather( let (mut input, mut index) = (values[0].clone(), values[1].clone()); index.flatten(); + let mut assigned_len = vec![]; if !input.all_prev_assigned() { input = region.assign(&config.inputs[0], &input)?; + assigned_len.push(input.len()); } if !index.all_prev_assigned() { index = region.assign(&config.inputs[1], &index)?; + assigned_len.push(index.len()); + } + + if !assigned_len.is_empty() { + region.increment(assigned_len.iter().max().unwrap().clone()); } - region.increment(std::cmp::max(input.len(), index.len())); // Calculate the output tensor size let input_dims = input.dims(); @@ -858,6 +864,115 @@ pub fn gather_elements( Ok(output.into()) } +/// Gather accumulated layout +pub fn scatter_elements( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 3], + dim: usize, +) -> Result, Box> { + let (mut input, mut index, mut src) = (values[0].clone(), values[1].clone(), values[2].clone()); + + assert_eq!(input.dims().len(), index.dims().len()); + + let mut assigned_len = vec![]; + + if !input.all_prev_assigned() { + input = region.assign(&config.inputs[0], &input)?; + assigned_len.push(input.len()); + } + if !index.all_prev_assigned() { + index = region.assign(&config.inputs[1], &index)?; + assigned_len.push(index.len()); + } + if !src.all_prev_assigned() { + src = region.assign(&config.output, &src)?; + assigned_len.push(src.len()); + } + + if !assigned_len.is_empty() { + region.increment(*assigned_len.iter().max().unwrap()); + } + + // Calculate the output tensor size + let input_dim = input.dims()[dim]; + let output_size = index.dims().to_vec(); + + // these will be assigned as constants + let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x))); + indices.set_visibility(&crate::graph::Visibility::Fixed); + let indices = region.assign(&config.inputs[1], &indices.into())?; + region.increment(indices.len()); + + // Allocate memory for the output tensor + let cartesian_coord = output_size + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let mut unit = Tensor::from(vec![F::from(1)].into_iter()); + unit.set_visibility(&crate::graph::Visibility::Fixed); + let unit: ValTensor = unit.into(); + region.assign(&config.inputs[1], &unit).unwrap(); + region.increment(1); + + let mut output = Tensor::new(None, &output_size)?; + + let mut inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| -> () { + let coord = cartesian_coord[i].clone(); + let index_val = index.get_inner_tensor().unwrap().get(&coord); + + let src_val = src.get_inner_tensor().unwrap().get(&coord); + let src_valtensor: ValTensor = Tensor::from([src_val.clone()].into_iter()).into(); + + let mut slice = coord.iter().map(|x| *x..*x + 1).collect::>(); + slice[dim] = 0..input_dim; + + let mut sliced_input = input.get_slice(&slice).unwrap(); + sliced_input.flatten(); + + let index_valtensor: ValTensor = Tensor::from([index_val.clone()].into_iter()).into(); + + let mask = equals(config, region, &[index_valtensor, indices.clone()]).unwrap(); + + let one_minus_mask = + pairwise(config, region, &[unit.clone(), mask.clone()], BaseOp::Sub).unwrap(); + + let pairwise_prod = pairwise(config, region, &[src_valtensor, mask], BaseOp::Mult).unwrap(); + let pairwise_prod_2 = pairwise( + config, + region, + &[sliced_input, one_minus_mask], + BaseOp::Mult, + ) + .unwrap(); + + let res = pairwise( + config, + region, + &[pairwise_prod, pairwise_prod_2], + BaseOp::Add, + ) + .unwrap(); + + let input_cartesian_coord = slice.into_iter().multi_cartesian_product(); + + let mutable_input_inner = input.get_inner_tensor_mut().unwrap(); + + for (i, r) in res.get_inner_tensor().unwrap().iter().enumerate() { + let coord = input_cartesian_coord.clone().nth(i).unwrap(); + *mutable_input_inner.get_mut(&coord) = r.clone(); + } + }; + + output.iter_mut().enumerate().for_each(|(i, o)| { + *o = inner_loop_function(i, region); + }); + + Ok(input) +} + /// sum accumulated layout pub fn sum( config: &BaseConfig, @@ -1750,20 +1865,21 @@ pub fn conv( if !input.all_prev_assigned() { input = region.assign(&config.inputs[0], &input)?; + region.increment(input.len()); } if input.dims().len() == 1 { return op(config, region, &[input]); } - region.increment(input.len()); - // Calculate the output tensor size let input_dims = input.dims(); diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index bcf5bae9a..670ff4741 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp}; #[cfg(not(target_arch = "wasm32"))] use tract_onnx::tract_core::ops::{ - array::{Gather, GatherElements, OneHot, Slice, Topk}, + array::{Gather, GatherElements, OneHot, ScatterElements, Slice, Topk}, change_axes::AxisOp, cnn::DeconvUnary, einsum::EinSum, @@ -281,6 +281,45 @@ pub fn new_op_from_onnx( num_classes, }) } + "ScatterElements" => { + if inputs.len() != 3 { + return Err(Box::new(GraphError::InvalidDims( + idx, + "scatter elements".to_string(), + ))); + }; + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axis = op.axis; + + let mut op = + SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements { + dim: axis, + constant_idx: None, + }); + + // if param_visibility.is_public() { + if let Some(c) = inputs[1].opkind().get_mutable_constant() { + inputs[1].decrement_const(); + deleted_indices.push(1); + op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements { + dim: axis, + constant_idx: Some(c.raw_values.map(|x| x as usize)), + }) + } + // } + + if inputs[1].opkind().is_input() { + inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input { + scale: 0, + datum_type: InputType::TDim, + })); + inputs[1].bump_scale(0); + } + + op + + // Extract the max value + } "GatherElements" => { if inputs.len() != 2 { return Err(Box::new(GraphError::InvalidDims( @@ -301,7 +340,7 @@ pub fn new_op_from_onnx( if let Some(c) = inputs[1].opkind().get_mutable_constant() { inputs[1].decrement_const(); deleted_indices.push(inputs.len() - 1); - op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather { + op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements { dim: axis, constant_idx: Some(c.raw_values.map(|x| x as usize)), }) diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 4a8bcea53..6774dc12f 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -538,6 +538,27 @@ impl Tensor { self[index].clone() } + /// Get a mutable array index from rows / columns indices. + /// + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(None, &[2, 3, 5]).unwrap(); + /// + /// a[1*15 + 1*5 + 1] = 5; + /// assert_eq!(a.get(&[1, 1, 1]), 5); + /// ``` + pub fn get_mut(&mut self, indices: &[usize]) -> &mut T { + assert_eq!(self.dims.len(), indices.len()); + let mut index = 0; + let mut d = 1; + for i in (0..indices.len()).rev() { + assert!(self.dims[i] > indices[i]); + index += indices[i] * d; + d *= self.dims[i]; + } + &mut self[index] + } + /// Get a single value from the Tensor. /// /// ``` @@ -567,7 +588,6 @@ impl Tensor { } /// Get a slice from the Tensor. - /// /// ``` /// use ezkl::tensor::Tensor; /// let mut a = Tensor::::new(Some(&[1, 2, 3]), &[3]).unwrap(); diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index f96315b15..05fdd30cc 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -1186,6 +1186,67 @@ pub fn gather( Ok(output) } +/// Scatters a tensor along a dimension. +/// # Arguments +/// * `input` - Tensor +/// * `dim` - Dimension to scatter along +/// * `index` - Tensor of indices to scatter +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::scatter; +/// let x = Tensor::::new( +/// Some(&[1.0, 2.0, 3.0, 4.0]), +/// &[2, 2], +/// ).unwrap(); +/// let src = Tensor::::new( +/// Some(&[5.0, 6.0, 7.0, 8.0]), +/// &[2, 2], +/// ).unwrap(); +/// let index = Tensor::::new( +/// Some(&[0, 0, 1, 0]), +/// &[2, 2], +/// ).unwrap(); +/// let result = scatter(&x, &index, &src, 0).unwrap(); +/// let expected = Tensor::::new(Some(&[5.0, 8.0, 7.0, 4.0]), &[2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn scatter( + input: &Tensor, + index: &Tensor, + src: &Tensor, + dim: usize, +) -> Result, TensorError> { + // Writes all values from the tensor src into self at the indices specified in the index tensor. + // For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim. + assert_eq!(index.dims(), src.dims()); + // Calculate the output tensor size + let src_size = src.dims().to_vec(); + + // For a 3-D tensor, self is updated as: + // self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + // self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + // self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + let mut output = input.clone(); + + let cartesian_coord = src_size + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + cartesian_coord.iter().for_each(|coord| { + let mut new_coord = coord.clone(); + let index_val = index.get(&coord); + new_coord[dim] = index_val; + let val = src.get(&coord); + output.set(&new_coord, val); + }); + + Ok(output) +} + /// Gathers a tensor along a dimension. /// # Arguments /// * `input` - Tensor diff --git a/src/tensor/val.rs b/src/tensor/val.rs index c34dcf342..51f1708cb 100644 --- a/src/tensor/val.rs +++ b/src/tensor/val.rs @@ -378,7 +378,7 @@ impl ValTensor { Ok(slice) } - /// Calls `get_slice` on the inner tensor. + /// Calls `get_single_elem` on the inner tensor. pub fn get_single_elem(&self, index: usize) -> Result, Box> { let slice = match self { ValTensor::Value { diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 44b8c5a3a..5de1cdab9 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -188,7 +188,7 @@ mod native_tests { "1l_prelu", ]; - const TESTS: [&str; 58] = [ + const TESTS: [&str; 59] = [ "1l_mlp", "1l_slice", "1l_concat", @@ -250,6 +250,7 @@ mod native_tests { "less", "xgboost_reg", "1l_powf", + "scatter_elements", ]; const WASM_TESTS: [&str; 48] = [ @@ -479,7 +480,7 @@ mod native_tests { - seq!(N in 0..=57 { + seq!(N in 0..=58 { #(#[test_case(TESTS[N])])* fn model_serialization_(test: &str) { let test_dir = TempDir::new(test).unwrap();