From 97b9744009b00751dcac2c1a9eebcb5cd59ce2a2 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Fri, 13 Dec 2024 13:23:46 +0800 Subject: [PATCH] support strides options for slice op --- src/lib/validate-input.js | 14 +++++++++++++- src/slice.js | 14 +++++++++++--- test/slice_test.js | 40 +++++++++++++++++++++++++++++++++++---- 3 files changed, 60 insertions(+), 8 deletions(-) diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index 756256c..78d8847 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -483,7 +483,7 @@ export function validateReduceParams(input, {axes}) { } } -export function validateSliceParams(input, starts, sizes) { +export function validateSliceParams(input, starts, sizes, {strides} = {}) { const rank = input.rank; if (starts.length !== rank) { throw new Error(`The length ${starts.length} of starts is not equal to the length ` + @@ -493,6 +493,11 @@ export function validateSliceParams(input, starts, sizes) { throw new Error(`The length ${sizes.length} of sizes is not equal` + ` to the length ${rank} of input.`); } + if (strides !== undefined && strides.length !== rank) { + throw new Error(`The length ${strides.length} of strides is not equal` + + ` to the length ${rank} of input.`); + } + for (let i = 0; i < rank; ++i) { const size = input.shape[i]; const start = starts[i]; @@ -507,6 +512,13 @@ export function validateSliceParams(input, starts, sizes) { if (!Number.isInteger(sliceSize) || sliceSize <= 0) { throw new Error(`Invalid sizes value ${sliceSize} - it should be an unsigned integer.`); } + if (strides !== undefined) { + const stride = strides[i]; + if (!Number.isInteger(stride) || stride < 1) { + throw new Error(`Invalid strides value ${stride} - it should be an unsigned integer ` + + 'greater than or equal to 1.'); + } + } if (start + sliceSize > size) { throw new Error(`Invalid sizes value ${sliceSize} - the sum of the start ${start} ` + `plus the size ${sliceSize} is greater than the dimensional size ${size}`); diff --git a/src/slice.js b/src/slice.js index 8848eca..698b8fd 100644 --- a/src/slice.js +++ b/src/slice.js @@ -8,17 +8,25 @@ import {validateSliceParams} from './lib/validate-input.js'; * @param {Tensor} input * @param {Array} starts * @param {Array} sizes + * @param {MLSliceOptions} options * @return {Tensor} */ -export function slice(input, starts, sizes) { +export function slice(input, starts, sizes, {strides} = {}) { validateSliceParams(...arguments); - const outputShape = sizes; + strides = strides ?? new Array(input.rank).fill(1); + const outputShape = input.shape.slice(); + + for (let dimensionIndex = 0; dimensionIndex < input.rank; ++dimensionIndex) { + outputShape[dimensionIndex] = Math.floor(sizes[dimensionIndex] / strides[dimensionIndex]) + + Number(sizes[dimensionIndex] % strides[dimensionIndex]); + } + const output = new Tensor(outputShape); for (let outputIndex = 0; outputIndex < sizeOfShape(outputShape); ++outputIndex) { const loc = output.locationFromIndex(outputIndex); const selectedInputLoc = loc.slice(); for (let i = 0; i < loc.length; ++i) { - selectedInputLoc[i] = loc[i] + starts[i]; + selectedInputLoc[i] = starts[i] + loc[i] * strides[i]; } const inputValue = input.getValueByLocation(selectedInputLoc); output.setValueByIndex(outputIndex, inputValue); diff --git a/test/slice_test.js b/test/slice_test.js index 9466079..3f2732b 100644 --- a/test/slice_test.js +++ b/test/slice_test.js @@ -5,14 +5,14 @@ import {Tensor} from '../src/lib/tensor.js'; import * as utils from './utils.js'; describe('test slice', function() { - function testSlice(inputShape, inputData, starts, sizes, expectedShape, expected) { + function testSlice(inputShape, inputData, starts, sizes, expectedShape, expected, options = {}) { const input = new Tensor(inputShape, inputData); - const output = slice(input, starts, sizes); + const output = slice(input, starts, sizes, options); utils.checkShape(output, expectedShape); utils.checkValue(output, expected); } - it('slice 1D', function() { + it('slice 1D default options', function() { const inputShape = [4]; const inputData = [0, 1, 2, 3]; const starts = [3]; @@ -22,7 +22,7 @@ describe('test slice', function() { testSlice(inputShape, inputData, starts, sizes, expectedShape, expected); }); - it('slice 3D', function() { + it('slice 3D default options', function() { const inputShape = [3, 4, 5]; const inputData = [ 1.3165863e+00, 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, @@ -54,4 +54,36 @@ describe('test slice', function() { ]; testSlice(inputShape, inputData, starts, sizes, expectedShape, expected); }); + + it('slice 2D with strides=[1, 5]', function() { + const inputShape = [3, 20]; + const inputData = [ + 1.3165863e+00, 4.1239005e-02, 4.6697399e-01, -6.6145003e-02, + -3.7128052e-01, -1.0660021e+00, 7.5784922e-01, 3.5759725e-02, + 1.9211160e+00, -8.1603736e-01, 1.1800343e-01, -1.8293047e+00, + -2.1316205e-01, -3.6369815e-01, 6.4205879e-01, 7.1544610e-02, + 6.8498695e-01, 1.0001093e+00, -5.6261641e-01, -7.3343945e-01, + + 1.6827687e+00, 1.2653192e+00, 5.8872145e-01, 3.1535852e-01, + 3.5038650e-01, 3.5865438e-01, -3.6469769e-01, -8.7751287e-01, + 2.7995768e-01, -1.6042528e+00, 8.6336482e-01, -1.7991974e+00, + -6.8652731e-01, 1.3729302e-03, -7.7775210e-01, 1.0199220e-01, + 4.2299256e-01, 1.1432177e-01, -5.0116669e-02, 1.5525131e+00, + + -8.7060851e-01, 4.5739245e-01, 1.3543987e-01, -1.5927458e-02, + 9.1792661e-01, -4.5001405e-01, 1.9954188e-01, -5.1338053e-01, + -4.1026011e-01, -1.2718531e+00, 4.2538303e-01, -1.5449624e-01, + -3.4380481e-01, 7.8374326e-01, 1.7837452e+00, 9.6105379e-01, + -4.8783422e-01, -9.4987392e-01, -8.8750905e-01, -9.8019439e-01, + ]; + const starts = [1, 4]; + const sizes = [2, 10]; + const strides = [1, 3]; + const expectedShape = [2, 4]; + const expected = [ + 3.5038650e-01, -8.7751287e-01, 8.6336482e-01, 1.3729302e-03, + 9.1792661e-01, -5.1338053e-01, 4.2538303e-01, 7.8374326e-01, + ]; + testSlice(inputShape, inputData, starts, sizes, expectedShape, expected, {strides}); + }); });