From 3b38c8c23d8c1ed65f4c6bc480dbe3fee3f42e1f Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Mon, 25 Sep 2023 20:37:36 +0100 Subject: [PATCH 1/4] feat: scatter elements argument --- examples/onnx/scatter_elements/gen.py | 46 +++++++++++++++ examples/onnx/scatter_elements/input.json | 1 + examples/onnx/scatter_elements/network.onnx | 30 ++++++++++ src/circuit/ops/hybrid.rs | 34 +++++++++++ src/circuit/ops/layouts.rs | 46 +++++++++++++++ src/graph/utilities.rs | 43 +++++++++++++- src/tensor/ops.rs | 62 +++++++++++++++++++++ tests/integration_tests.rs | 5 +- 8 files changed, 263 insertions(+), 4 deletions(-) create mode 100644 examples/onnx/scatter_elements/gen.py create mode 100644 examples/onnx/scatter_elements/input.json create mode 100644 examples/onnx/scatter_elements/network.onnx diff --git a/examples/onnx/scatter_elements/gen.py b/examples/onnx/scatter_elements/gen.py new file mode 100644 index 000000000..aca91f392 --- /dev/null +++ b/examples/onnx/scatter_elements/gen.py @@ -0,0 +1,46 @@ +from torch import nn +import torch +import json +import numpy as np + + +class MyModel(nn.Module): + def __init__(self): + super(MyModel, self).__init__() + + def forward(self, w, x, src): + # scatter_elements + return w.scatter(2, x, src) + + +circuit = MyModel() + + +w = torch.rand(1, 15, 18) +src = torch.rand(1, 15, 2) +x = torch.randint(0, 15, (1, 15, 2)) + +torch.onnx.export(circuit, (w, x, src), "network.onnx", + export_params=True, # store the trained parameter weights inside the model file + opset_version=15, # the ONNX version to export the model to + do_constant_folding=True, # whether to execute constant folding for optimization + # the model's input names + input_names=['input', 'input1', 'input2'], + output_names=['output'], # the model's output names + dynamic_axes={'input': {0: 'batch_size'}, # variable length axes + 'input1': {0: 'batch_size'}, + 'input2': {0: 'batch_size'}, + 'output': {0: 'batch_size'}}) + + +d = ((w).detach().numpy()).reshape([-1]).tolist() +d1 = ((x).detach().numpy()).reshape([-1]).tolist() +d2 = ((src).detach().numpy()).reshape([-1]).tolist() + + +data = dict( + input_data=[d, d1, d2], +) + +# Serialize data into file: +json.dump(data, open("input.json", 'w')) diff --git a/examples/onnx/scatter_elements/input.json b/examples/onnx/scatter_elements/input.json new file mode 100644 index 000000000..ef333c1b8 --- /dev/null +++ b/examples/onnx/scatter_elements/input.json @@ -0,0 +1 @@ +{"input_data": [[0.4006722569465637, 0.2608814835548401, 0.018605709075927734, 0.2076982855796814, 0.09738868474960327, 0.7288642525672913, 0.5645512342453003, 0.1282898187637329, 0.4371366500854492, 0.06852233409881592, 0.2802664637565613, 0.13553869724273682, 0.34616756439208984, 0.07160574197769165, 0.003862321376800537, 0.48938989639282227, 0.44164425134658813, 0.5328972935676575, 0.07106280326843262, 0.09725075960159302, 0.2589840292930603, 0.23286551237106323, 0.3244437575340271, 0.6505903005599976, 0.938034176826477, 0.26783543825149536, 0.786094069480896, 0.18789011240005493, 0.606765627861023, 0.19948190450668335, 0.3781137466430664, 0.8510081768035889, 0.13566946983337402, 0.9276034832000732, 0.5269846320152283, 0.3013773560523987, 0.5438669919967651, 0.12605935335159302, 0.4066329002380371, 0.24718689918518066, 0.4687347412109375, 0.23328912258148193, 0.8731061220169067, 0.3781728744506836, 0.5772581696510315, 0.8565052151679993, 0.7540851831436157, 0.779710590839386, 0.2353827953338623, 0.5320121645927429, 0.5085064172744751, 0.6413846015930176, 0.0774533748626709, 0.43543362617492676, 0.3591287136077881, 0.0804855227470398, 0.24766361713409424, 0.29354000091552734, 0.424951434135437, 0.43333709239959717, 0.40345197916030884, 0.8565409183502197, 0.06852656602859497, 0.04801344871520996, 0.1688656210899353, 0.9150650501251221, 0.3046884536743164, 0.380337119102478, 0.699865996837616, 0.7464215755462646, 0.390849232673645, 0.8992515802383423, 0.612838089466095, 0.4444986581802368, 0.9558160901069641, 0.7544505000114441, 0.394461989402771, 0.8434802293777466, 0.5277895331382751, 0.20150727033615112, 0.8635226488113403, 0.588026762008667, 0.260425865650177, 0.11915439367294312, 0.630763053894043, 0.35954052209854126, 0.6132745742797852, 0.5505245923995972, 0.557021975517273, 0.6679773926734924, 0.5876642465591431, 0.6560839414596558, 0.9582574963569641, 0.3779008984565735, 0.9340549111366272, 0.18399536609649658, 0.25960028171539307, 0.026599764823913574, 0.32735109329223633, 0.904168963432312, 0.14129406213760376, 0.82865971326828, 0.1153140664100647, 0.32517731189727783, 0.5030156970024109, 0.751384437084198, 0.6885764598846436, 0.34171855449676514, 0.6160555481910706, 0.16125458478927612, 0.9489563703536987, 0.1679406762123108, 0.20235764980316162, 0.881341278553009, 0.5121568441390991, 0.92611163854599, 0.4272810220718384, 0.2771685719490051, 0.9968149065971375, 0.15203642845153809, 0.4641759395599365, 0.413804292678833, 0.651200532913208, 0.32379841804504395, 0.6092511415481567, 0.8478893637657166, 0.9164987802505493, 0.7830571532249451, 0.8511074185371399, 0.15598517656326294, 0.37074369192123413, 0.4150542616844177, 0.5372101068496704, 0.14770209789276123, 0.23627835512161255, 0.06493926048278809, 0.09793734550476074, 0.8952625393867493, 0.3102150559425354, 0.43397819995880127, 0.5844067931175232, 0.3500853180885315, 0.22773021459579468, 0.05151098966598511, 0.6001721620559692, 0.3343912363052368, 0.46624046564102173, 0.549772322177887, 0.7372040748596191, 0.042217373847961426, 0.49783337116241455, 0.6191272735595703, 0.10424506664276123, 0.25230395793914795, 0.6359071135520935, 0.2743602991104126, 0.9432371258735657, 0.9513012170791626, 0.03317856788635254, 0.6779598593711853, 0.16297829151153564, 0.7508324384689331, 0.45239394903182983, 0.4571657180786133, 0.2818666100502014, 0.17718660831451416, 0.3443911671638489, 0.9507452845573425, 0.9657842516899109, 0.5738788843154907, 0.1653851866722107, 0.9145194292068481, 0.3741937279701233, 0.05001652240753174, 0.984059751033783, 0.6154639720916748, 0.6324515342712402, 0.08717000484466553, 0.894913911819458, 0.15595513582229614, 0.7157255411148071, 0.7270150780677795, 0.8960562944412231, 0.8680596947669983, 0.4046359062194824, 0.13201695680618286, 0.7696574926376343, 0.6563123464584351, 0.4042320251464844, 0.7971198558807373, 0.3859425187110901, 0.1174202561378479, 0.689612865447998, 0.3647807240486145, 0.4592200517654419, 0.05106186866760254, 0.6523975133895874, 0.2841695547103882, 0.07473695278167725, 0.5434084534645081, 0.20653265714645386, 0.31296950578689575, 0.3648838400840759, 0.5561838150024414, 0.34996211528778076, 0.4311484098434448, 0.7667861580848694, 0.8519712686538696, 0.9926041960716248, 0.11216002702713013, 0.632348895072937, 0.12260669469833374, 0.4656309485435486, 0.9705881476402283, 0.21109485626220703, 0.08676731586456299, 0.16977214813232422, 0.16263240575790405, 0.685645580291748, 0.14296847581863403, 0.6883099675178528, 0.046499550342559814, 0.026215791702270508, 0.2830098867416382, 0.5826435685157776, 0.5168894529342651, 0.21555876731872559, 0.6396673917770386, 0.5840371251106262, 0.8724454641342163, 0.13193941116333008, 0.6795164942741394, 0.15844780206680298, 0.27272599935531616, 0.6294320225715637, 0.35720330476760864, 0.8047906160354614, 0.392769455909729, 0.12731897830963135, 0.0994114875793457, 0.06120210886001587, 0.1784428358078003, 0.35258960723876953, 0.24094289541244507, 0.49197083711624146, 0.5611958503723145, 0.9082142114639282, 0.010551989078521729, 0.22507983446121216, 0.007577002048492432, 0.27637380361557007, 0.8734598159790039, 0.623768150806427, 0.10755449533462524, 0.4494497776031494, 0.06208443641662598, 0.5997845530509949, 0.8712562322616577, 0.4752771854400635, 0.3971136808395386, 0.01836305856704712, 0.4282991290092468, 0.40080422163009644, 0.757233738899231, 0.17318272590637207, 0.42513030767440796, 0.06316602230072021, 0.5725610256195068, 0.14671063423156738, 0.8214869499206543], [2, 0, 13, 12, 1, 0, 6, 3, 0, 4, 10, 0, 9, 3, 14, 13, 12, 0, 0, 11, 8, 14, 0, 8, 13, 6, 14, 9, 2, 10], [0.3678196668624878, 0.5077707171440125, 0.7048060894012451, 0.824059784412384, 0.5817335247993469, 0.1481603980064392, 0.7114542722702026, 0.46050769090652466, 0.2019960880279541, 0.47866880893707275, 0.6982521414756775, 0.8433347344398499, 0.3427056074142456, 0.182439923286438, 0.37574678659439087, 0.93959641456604, 0.5769227147102356, 0.5917701125144958, 0.30993229150772095, 0.26509368419647217, 0.3919970393180847, 0.04614824056625366, 0.09493398666381836, 0.8279429078102112, 0.29439830780029297, 0.42815113067626953, 0.4152073264122009, 0.33196568489074707, 0.24249297380447388, 0.9800586700439453]]} \ No newline at end of file diff --git a/examples/onnx/scatter_elements/network.onnx b/examples/onnx/scatter_elements/network.onnx new file mode 100644 index 000000000..34eeeb329 --- /dev/null +++ b/examples/onnx/scatter_elements/network.onnx @@ -0,0 +1,30 @@ +pytorch2.0.1:û +O +input +input1 +input2output/ScatterElements"ScatterElements* +axis  torch_jitZ% +input + +  +batch_size + +Z& +input1 + +  +batch_size + +Z& +input2 + +  +batch_size + +b& +output + +  +batch_size + +B \ No newline at end of file diff --git a/src/circuit/ops/hybrid.rs b/src/circuit/ops/hybrid.rs index 57cd76380..ebd916d85 100644 --- a/src/circuit/ops/hybrid.rs +++ b/src/circuit/ops/hybrid.rs @@ -57,6 +57,10 @@ pub enum HybridOp { dim: usize, constant_idx: Option>, }, + ScatterElements { + dim: usize, + constant_idx: Option>, + }, } impl Op for HybridOp { @@ -64,6 +68,7 @@ impl Op for HybridOp { fn requires_homogenous_input_scales(&self) -> Vec { match self { HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1], + HybridOp::ScatterElements { .. } => vec![0, 2], _ => vec![], } } @@ -176,6 +181,20 @@ impl Op for HybridOp { (res.clone(), inter_equals) } } + HybridOp::ScatterElements { dim, constant_idx } => { + let src = inputs[2].clone().map(|x| felt_to_i128(x)); + if let Some(idx) = constant_idx { + log::debug!("idx: {}", idx.show()); + let res = tensor::ops::scatter(&x, idx, &src, *dim)?; + (res.clone(), vec![]) + } else { + let idx = inputs[1].clone().map(|x| felt_to_i128(x)); + let inter_equals: Vec> = + vec![Tensor::from(0..x.dims()[*dim] as i128)]; + let res = tensor::ops::scatter(&x, &idx.map(|x| x as usize), &src, *dim)?; + (res.clone(), inter_equals) + } + } HybridOp::MaxPool2d { padding, stride, @@ -253,6 +272,7 @@ impl Op for HybridOp { HybridOp::Gather { dim, .. } => format!("GATHER (dim={})", dim), HybridOp::TopK { k, dim } => format!("TOPK (k={}, dim={})", k, dim), HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim), + HybridOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim), HybridOp::OneHot { dim, num_classes } => { format!("ONEHOT (dim={}, num_classes={})", dim, num_classes) } @@ -281,6 +301,19 @@ impl Op for HybridOp { layouts::gather_elements(config, region, values[..].try_into()?, *dim)? } } + HybridOp::ScatterElements { dim, constant_idx } => { + if let Some(idx) = constant_idx { + tensor::ops::scatter( + values[0].get_inner_tensor()?, + idx, + values[1].get_inner_tensor()?, + *dim, + )? + .into() + } else { + layouts::scatter_elements(config, region, values[..].try_into()?, *dim)? + } + } HybridOp::MaxPool2d { padding, stride, @@ -395,6 +428,7 @@ impl Op for HybridOp { HybridOp::Gather { .. } | HybridOp::OneHot { .. } | HybridOp::GatherElements { .. } + | HybridOp::ScatterElements { .. } | HybridOp::Equals { .. } => { vec![LookupOp::KroneckerDelta] } diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 0af2b034f..bce8c3ef7 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -783,6 +783,52 @@ pub fn gather( Ok(output.into()) } +/// +pub fn scatter_elements( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 3], + dim: usize, +) -> Result, Box> { + // first create a claim + let (input, index, src) = (values[0].clone(), values[1].clone(), values[2].clone()); + + let mut index_usize = if !index.any_unknowns() { + index.get_int_evals()?.map(|x| x as usize) + } else { + Tensor::new(None, index.dims())? + }; + index_usize.reshape(index.dims()); + + let claimed_output = tensor::ops::scatter( + input.get_inner_tensor()?, + &index_usize, + src.get_inner_tensor()?, + dim, + )?; + + let assigned_claimed_output = region.assign(&config.inputs[0], &claimed_output.into())?; + let assigned_src = region.assign(&config.inputs[1], &src)?; + let assigned_index = region.assign(&config.output, &index)?; + + region.increment(std::cmp::max( + assigned_claimed_output.len(), + std::cmp::max(assigned_src.len(), assigned_index.len()), + )); + + // now assert that the claimed output gathered using the index is equal to the src + let gathered = gather_elements( + config, + region, + &[assigned_claimed_output.clone(), assigned_index], + dim, + )?; + + enforce_equality(config, region, &[gathered, assigned_src])?; + + Ok(assigned_claimed_output) +} + /// Gather accumulated layout pub fn gather_elements( config: &BaseConfig, diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index bcf5bae9a..670ff4741 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp}; #[cfg(not(target_arch = "wasm32"))] use tract_onnx::tract_core::ops::{ - array::{Gather, GatherElements, OneHot, Slice, Topk}, + array::{Gather, GatherElements, OneHot, ScatterElements, Slice, Topk}, change_axes::AxisOp, cnn::DeconvUnary, einsum::EinSum, @@ -281,6 +281,45 @@ pub fn new_op_from_onnx( num_classes, }) } + "ScatterElements" => { + if inputs.len() != 3 { + return Err(Box::new(GraphError::InvalidDims( + idx, + "scatter elements".to_string(), + ))); + }; + let op = load_op::(node.op(), idx, node.op().name().to_string())?; + let axis = op.axis; + + let mut op = + SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements { + dim: axis, + constant_idx: None, + }); + + // if param_visibility.is_public() { + if let Some(c) = inputs[1].opkind().get_mutable_constant() { + inputs[1].decrement_const(); + deleted_indices.push(1); + op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements { + dim: axis, + constant_idx: Some(c.raw_values.map(|x| x as usize)), + }) + } + // } + + if inputs[1].opkind().is_input() { + inputs[1].replace_opkind(SupportedOp::Input(crate::circuit::ops::Input { + scale: 0, + datum_type: InputType::TDim, + })); + inputs[1].bump_scale(0); + } + + op + + // Extract the max value + } "GatherElements" => { if inputs.len() != 2 { return Err(Box::new(GraphError::InvalidDims( @@ -301,7 +340,7 @@ pub fn new_op_from_onnx( if let Some(c) = inputs[1].opkind().get_mutable_constant() { inputs[1].decrement_const(); deleted_indices.push(inputs.len() - 1); - op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::Gather { + op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements { dim: axis, constant_idx: Some(c.raw_values.map(|x| x as usize)), }) diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index f96315b15..300c62f9d 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -1186,6 +1186,68 @@ pub fn gather( Ok(output) } +/// Scatters a tensor along a dimension. +/// # Arguments +/// * `input` - Tensor +/// * `dim` - Dimension to scatter along +/// * `index` - Tensor of indices to scatter +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::tensor::ops::scatter; +/// let x = Tensor::::new( +/// Some(&[1.0, 2.0, 3.0, 4.0]), +/// &[2, 2], +/// ).unwrap(); +/// let src = Tensor::::new( +/// Some(&[5.0, 6.0, 7.0, 8.0]), +/// &[2, 2], +/// ).unwrap(); +/// let index = Tensor::::new( +/// Some(&[0, 0, 1, 0]), +/// &[2, 2], +/// ).unwrap(); +/// let result = scatter(&x, &index, &src, 0).unwrap(); +/// let expected = Tensor::::new(Some(&[5.0, 8.0, 7.0, 4.0]), &[2, 2]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn scatter( + input: &Tensor, + index: &Tensor, + src: &Tensor, + dim: usize, +) -> Result, TensorError> { + println!("scatter"); + // Writes all values from the tensor src into self at the indices specified in the index tensor. + // For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim. + assert_eq!(index.dims(), src.dims()); + // Calculate the output tensor size + let src_size = src.dims().to_vec(); + + // For a 3-D tensor, self is updated as: + // self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + // self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + // self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + let mut output = input.clone(); + + let cartesian_coord = src_size + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + cartesian_coord.iter().for_each(|coord| { + let mut new_coord = coord.clone(); + let index_val = index.get(&coord); + new_coord[dim] = index_val; + let val = src.get(&coord); + output.set(&new_coord, val); + }); + + Ok(output) +} + /// Gathers a tensor along a dimension. /// # Arguments /// * `input` - Tensor diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index a35d0bd1e..675e79452 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -179,7 +179,7 @@ mod native_tests { "1l_prelu", ]; - const TESTS: [&str; 58] = [ + const TESTS: [&str; 59] = [ "1l_mlp", "1l_slice", "1l_concat", @@ -241,6 +241,7 @@ mod native_tests { "less", "xgboost_reg", "1l_powf", + "scatter_elements", ]; const WASM_TESTS: [&str; 48] = [ @@ -470,7 +471,7 @@ mod native_tests { - seq!(N in 0..=57 { + seq!(N in 0..=58 { #(#[test_case(TESTS[N])])* fn model_serialization_(test: &str) { let test_dir = TempDir::new(test).unwrap(); From 23641716089ebd1f83a49ccf26c16dcaca800482 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Mon, 25 Sep 2023 20:54:18 +0100 Subject: [PATCH 2/4] clean up --- src/circuit/ops/layouts.rs | 19 +++++++++++-------- src/tensor/ops.rs | 1 - 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index bce8c3ef7..2c4411138 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -791,7 +791,12 @@ pub fn scatter_elements( dim: usize, ) -> Result, Box> { // first create a claim - let (input, index, src) = (values[0].clone(), values[1].clone(), values[2].clone()); + let (mut input, index, src) = (values[0].clone(), values[1].clone(), values[2].clone()); + + if !input.all_prev_assigned() { + input = region.assign(&config.inputs[0], &input)?; + } + let assigned_src = region.assign(&config.inputs[1], &src)?; let mut index_usize = if !index.any_unknowns() { index.get_int_evals()?.map(|x| x as usize) @@ -800,27 +805,25 @@ pub fn scatter_elements( }; index_usize.reshape(index.dims()); + // this will get copy constrained with the input when assigned let claimed_output = tensor::ops::scatter( input.get_inner_tensor()?, &index_usize, - src.get_inner_tensor()?, + assigned_src.get_inner_tensor()?, dim, )?; - - let assigned_claimed_output = region.assign(&config.inputs[0], &claimed_output.into())?; - let assigned_src = region.assign(&config.inputs[1], &src)?; - let assigned_index = region.assign(&config.output, &index)?; + let assigned_claimed_output = region.assign(&config.output, &claimed_output.into())?; region.increment(std::cmp::max( assigned_claimed_output.len(), - std::cmp::max(assigned_src.len(), assigned_index.len()), + assigned_src.len(), )); // now assert that the claimed output gathered using the index is equal to the src let gathered = gather_elements( config, region, - &[assigned_claimed_output.clone(), assigned_index], + &[assigned_claimed_output.clone(), index], dim, )?; diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 300c62f9d..05fdd30cc 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -1217,7 +1217,6 @@ pub fn scatter( src: &Tensor, dim: usize, ) -> Result, TensorError> { - println!("scatter"); // Writes all values from the tensor src into self at the indices specified in the index tensor. // For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim. assert_eq!(index.dims(), src.dims()); From 5d83e71d779748a5dae82246f7c919d15634f29a Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Tue, 26 Sep 2023 18:46:41 +0100 Subject: [PATCH 3/4] patch scatter --- src/circuit/ops/layouts.rs | 184 +++++++++++++++++++++++++------------ src/tensor/mod.rs | 24 ++++- src/tensor/val.rs | 2 +- 3 files changed, 149 insertions(+), 61 deletions(-) diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 2c4411138..30b10fd8d 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -553,7 +553,7 @@ fn select( input.flatten(); // assert we have a single index - // assert_eq!(index.dims().iter().product::(), 1); + assert_eq!(index.dims().iter().product::(), 1); assert!(dim_indices.all_prev_assigned() || region.is_dummy()); let is_assigned = !input.any_unknowns() && !index.any_unknowns(); @@ -715,13 +715,19 @@ pub fn gather( let (mut input, mut index) = (values[0].clone(), values[1].clone()); index.flatten(); + let mut assigned_len = vec![]; if !input.all_prev_assigned() { input = region.assign(&config.inputs[0], &input)?; + assigned_len.push(input.len()); } if !index.all_prev_assigned() { index = region.assign(&config.inputs[1], &index)?; + assigned_len.push(index.len()); + } + + if !assigned_len.is_empty() { + region.increment(assigned_len.iter().max().unwrap().clone()); } - region.increment(std::cmp::max(input.len(), index.len())); // Calculate the output tensor size let input_dims = input.dims(); @@ -783,55 +789,6 @@ pub fn gather( Ok(output.into()) } -/// -pub fn scatter_elements( - config: &BaseConfig, - region: &mut RegionCtx, - values: &[ValTensor; 3], - dim: usize, -) -> Result, Box> { - // first create a claim - let (mut input, index, src) = (values[0].clone(), values[1].clone(), values[2].clone()); - - if !input.all_prev_assigned() { - input = region.assign(&config.inputs[0], &input)?; - } - let assigned_src = region.assign(&config.inputs[1], &src)?; - - let mut index_usize = if !index.any_unknowns() { - index.get_int_evals()?.map(|x| x as usize) - } else { - Tensor::new(None, index.dims())? - }; - index_usize.reshape(index.dims()); - - // this will get copy constrained with the input when assigned - let claimed_output = tensor::ops::scatter( - input.get_inner_tensor()?, - &index_usize, - assigned_src.get_inner_tensor()?, - dim, - )?; - let assigned_claimed_output = region.assign(&config.output, &claimed_output.into())?; - - region.increment(std::cmp::max( - assigned_claimed_output.len(), - assigned_src.len(), - )); - - // now assert that the claimed output gathered using the index is equal to the src - let gathered = gather_elements( - config, - region, - &[assigned_claimed_output.clone(), index], - dim, - )?; - - enforce_equality(config, region, &[gathered, assigned_src])?; - - Ok(assigned_claimed_output) -} - /// Gather accumulated layout pub fn gather_elements( config: &BaseConfig, @@ -907,6 +864,115 @@ pub fn gather_elements( Ok(output.into()) } +/// Gather accumulated layout +pub fn scatter_elements( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 3], + dim: usize, +) -> Result, Box> { + let (mut input, mut index, mut src) = (values[0].clone(), values[1].clone(), values[2].clone()); + + assert_eq!(input.dims().len(), index.dims().len()); + + let mut assigned_len = vec![]; + + if !input.all_prev_assigned() { + input = region.assign(&config.inputs[0], &input)?; + assigned_len.push(input.len()); + } + if !index.all_prev_assigned() { + index = region.assign(&config.inputs[1], &index)?; + assigned_len.push(index.len()); + } + if !src.all_prev_assigned() { + src = region.assign(&config.output, &src)?; + assigned_len.push(src.len()); + } + + if !assigned_len.is_empty() { + region.increment(*assigned_len.iter().max().unwrap()); + } + + // Calculate the output tensor size + let input_dim = input.dims()[dim]; + let output_size = index.dims().to_vec(); + + // these will be assigned as constants + let mut indices = Tensor::from((0..input_dim as u64).map(|x| F::from(x))); + indices.set_visibility(&crate::graph::Visibility::Fixed); + let indices = region.assign(&config.inputs[1], &indices.into())?; + region.increment(indices.len()); + + // Allocate memory for the output tensor + let cartesian_coord = output_size + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let mut unit = Tensor::from(vec![F::from(1)].into_iter()); + unit.set_visibility(&crate::graph::Visibility::Fixed); + let unit: ValTensor = unit.into(); + region.assign(&config.inputs[1], &unit).unwrap(); + region.increment(1); + + let mut output = Tensor::new(None, &output_size)?; + + let mut inner_loop_function = |i: usize, region: &mut RegionCtx<'_, F>| -> () { + let coord = cartesian_coord[i].clone(); + let index_val = index.get_inner_tensor().unwrap().get(&coord); + + let src_val = src.get_inner_tensor().unwrap().get(&coord); + let src_valtensor: ValTensor = Tensor::from([src_val.clone()].into_iter()).into(); + + let mut slice = coord.iter().map(|x| *x..*x + 1).collect::>(); + slice[dim] = 0..input_dim; + + let mut sliced_input = input.get_slice(&slice).unwrap(); + sliced_input.flatten(); + + let index_valtensor: ValTensor = Tensor::from([index_val.clone()].into_iter()).into(); + + let mask = equals(config, region, &[index_valtensor, indices.clone()]).unwrap(); + + let one_minus_mask = + pairwise(config, region, &[unit.clone(), mask.clone()], BaseOp::Sub).unwrap(); + + let pairwise_prod = pairwise(config, region, &[src_valtensor, mask], BaseOp::Mult).unwrap(); + let pairwise_prod_2 = pairwise( + config, + region, + &[sliced_input, one_minus_mask], + BaseOp::Mult, + ) + .unwrap(); + + let res = pairwise( + config, + region, + &[pairwise_prod, pairwise_prod_2], + BaseOp::Add, + ) + .unwrap(); + + let input_cartesian_coord = slice.into_iter().multi_cartesian_product(); + + let mutable_input_inner = input.get_inner_tensor_mut().unwrap(); + + for (i, r) in res.get_inner_tensor().unwrap().iter().enumerate() { + let coord = input_cartesian_coord.clone().nth(i).unwrap(); + *mutable_input_inner.get_mutable_index(&coord) = r.clone(); + } + }; + + output.iter_mut().enumerate().for_each(|(i, o)| { + *o = inner_loop_function(i, region); + }); + + Ok(input) +} + /// sum accumulated layout pub fn sum( config: &BaseConfig, @@ -1799,20 +1865,21 @@ pub fn conv( if !input.all_prev_assigned() { input = region.assign(&config.inputs[0], &input)?; + region.increment(input.len()); } if input.dims().len() == 1 { return op(config, region, &[input]); } - region.increment(input.len()); - // Calculate the output tensor size let input_dims = input.dims(); diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 4a8bcea53..c0d5a6118 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -567,7 +567,6 @@ impl Tensor { } /// Get a slice from the Tensor. - /// /// ``` /// use ezkl::tensor::Tensor; /// let mut a = Tensor::::new(Some(&[1, 2, 3]), &[3]).unwrap(); @@ -640,6 +639,29 @@ impl Tensor { index } + /// Get the array index from rows / columns indices. + /// + /// ``` + /// use ezkl::tensor::Tensor; + /// let a = Tensor::::new(None, &[3, 3, 3]).unwrap(); + /// + /// assert_eq!(a.get_mutable_index(&[2, 2, 2]), 26); + /// assert_eq!(a.get_mutable_index(&[1, 2, 2]), 17); + /// assert_eq!(a.get_mutable_index(&[1, 2, 0]), 15); + /// assert_eq!(a.get_mutable_index(&[1, 0, 1]), 10); + /// ``` + pub fn get_mutable_index(&mut self, indices: &[usize]) -> &mut T { + assert_eq!(self.dims.len(), indices.len()); + let mut index = 0; + let mut d = 1; + for i in (0..indices.len()).rev() { + assert!(self.dims[i] > indices[i]); + index += indices[i] * d; + d *= self.dims[i]; + } + &mut self[index] + } + /// Duplicates every nth element /// /// ``` diff --git a/src/tensor/val.rs b/src/tensor/val.rs index c34dcf342..51f1708cb 100644 --- a/src/tensor/val.rs +++ b/src/tensor/val.rs @@ -378,7 +378,7 @@ impl ValTensor { Ok(slice) } - /// Calls `get_slice` on the inner tensor. + /// Calls `get_single_elem` on the inner tensor. pub fn get_single_elem(&self, index: usize) -> Result, Box> { let slice = match self { ValTensor::Value { From e4a6133633209ce8a0c37705baf83103ca5a9e21 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Tue, 26 Sep 2023 19:11:13 +0100 Subject: [PATCH 4/4] get_mut --- src/circuit/ops/layouts.rs | 2 +- src/tensor/mod.rs | 44 ++++++++++++++++++-------------------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 30b10fd8d..dd07008d3 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -962,7 +962,7 @@ pub fn scatter_elements( for (i, r) in res.get_inner_tensor().unwrap().iter().enumerate() { let coord = input_cartesian_coord.clone().nth(i).unwrap(); - *mutable_input_inner.get_mutable_index(&coord) = r.clone(); + *mutable_input_inner.get_mut(&coord) = r.clone(); } }; diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index c0d5a6118..6774dc12f 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -538,6 +538,27 @@ impl Tensor { self[index].clone() } + /// Get a mutable array index from rows / columns indices. + /// + /// ``` + /// use ezkl::tensor::Tensor; + /// let mut a = Tensor::::new(None, &[2, 3, 5]).unwrap(); + /// + /// a[1*15 + 1*5 + 1] = 5; + /// assert_eq!(a.get(&[1, 1, 1]), 5); + /// ``` + pub fn get_mut(&mut self, indices: &[usize]) -> &mut T { + assert_eq!(self.dims.len(), indices.len()); + let mut index = 0; + let mut d = 1; + for i in (0..indices.len()).rev() { + assert!(self.dims[i] > indices[i]); + index += indices[i] * d; + d *= self.dims[i]; + } + &mut self[index] + } + /// Get a single value from the Tensor. /// /// ``` @@ -639,29 +660,6 @@ impl Tensor { index } - /// Get the array index from rows / columns indices. - /// - /// ``` - /// use ezkl::tensor::Tensor; - /// let a = Tensor::::new(None, &[3, 3, 3]).unwrap(); - /// - /// assert_eq!(a.get_mutable_index(&[2, 2, 2]), 26); - /// assert_eq!(a.get_mutable_index(&[1, 2, 2]), 17); - /// assert_eq!(a.get_mutable_index(&[1, 2, 0]), 15); - /// assert_eq!(a.get_mutable_index(&[1, 0, 1]), 10); - /// ``` - pub fn get_mutable_index(&mut self, indices: &[usize]) -> &mut T { - assert_eq!(self.dims.len(), indices.len()); - let mut index = 0; - let mut d = 1; - for i in (0..indices.len()).rev() { - assert!(self.dims[i] > indices[i]); - index += indices[i] * d; - d *= self.dims[i]; - } - &mut self[index] - } - /// Duplicates every nth element /// /// ```