Skip to content

Commit

Permalink
support strides options for slice op
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai committed Dec 13, 2024
1 parent 4febe57 commit 97b9744
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 8 deletions.
14 changes: 13 additions & 1 deletion src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ export function validateReduceParams(input, {axes}) {
}
}

export function validateSliceParams(input, starts, sizes) {
export function validateSliceParams(input, starts, sizes, {strides} = {}) {
const rank = input.rank;
if (starts.length !== rank) {
throw new Error(`The length ${starts.length} of starts is not equal to the length ` +
Expand All @@ -493,6 +493,11 @@ export function validateSliceParams(input, starts, sizes) {
throw new Error(`The length ${sizes.length} of sizes is not equal` +
` to the length ${rank} of input.`);
}
if (strides !== undefined && strides.length !== rank) {
throw new Error(`The length ${strides.length} of strides is not equal` +
` to the length ${rank} of input.`);
}

for (let i = 0; i < rank; ++i) {
const size = input.shape[i];
const start = starts[i];
Expand All @@ -507,6 +512,13 @@ export function validateSliceParams(input, starts, sizes) {
if (!Number.isInteger(sliceSize) || sliceSize <= 0) {
throw new Error(`Invalid sizes value ${sliceSize} - it should be an unsigned integer.`);
}
if (strides !== undefined) {
const stride = strides[i];
if (!Number.isInteger(stride) || stride < 1) {
throw new Error(`Invalid strides value ${stride} - it should be an unsigned integer ` +
'greater than or equal to 1.');
}
}
if (start + sliceSize > size) {
throw new Error(`Invalid sizes value ${sliceSize} - the sum of the start ${start} ` +
`plus the size ${sliceSize} is greater than the dimensional size ${size}`);
Expand Down
14 changes: 11 additions & 3 deletions src/slice.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,25 @@ import {validateSliceParams} from './lib/validate-input.js';
* @param {Tensor} input
* @param {Array} starts
* @param {Array} sizes
* @param {MLSliceOptions} options
* @return {Tensor}
*/
export function slice(input, starts, sizes) {
export function slice(input, starts, sizes, {strides} = {}) {
validateSliceParams(...arguments);
const outputShape = sizes;
strides = strides ?? new Array(input.rank).fill(1);
const outputShape = input.shape.slice();

for (let dimensionIndex = 0; dimensionIndex < input.rank; ++dimensionIndex) {
outputShape[dimensionIndex] = Math.floor(sizes[dimensionIndex] / strides[dimensionIndex]) +
Number(sizes[dimensionIndex] % strides[dimensionIndex]);
}

const output = new Tensor(outputShape);
for (let outputIndex = 0; outputIndex < sizeOfShape(outputShape); ++outputIndex) {
const loc = output.locationFromIndex(outputIndex);
const selectedInputLoc = loc.slice();
for (let i = 0; i < loc.length; ++i) {
selectedInputLoc[i] = loc[i] + starts[i];
selectedInputLoc[i] = starts[i] + loc[i] * strides[i];
}
const inputValue = input.getValueByLocation(selectedInputLoc);
output.setValueByIndex(outputIndex, inputValue);
Expand Down
40 changes: 36 additions & 4 deletions test/slice_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test slice', function() {
function testSlice(inputShape, inputData, starts, sizes, expectedShape, expected) {
function testSlice(inputShape, inputData, starts, sizes, expectedShape, expected, options = {}) {
const input = new Tensor(inputShape, inputData);
const output = slice(input, starts, sizes);
const output = slice(input, starts, sizes, options);
utils.checkShape(output, expectedShape);
utils.checkValue(output, expected);
}

it('slice 1D', function() {
it('slice 1D default options', function() {
const inputShape = [4];
const inputData = [0, 1, 2, 3];
const starts = [3];
Expand All @@ -22,7 +22,7 @@ describe('test slice', function() {
testSlice(inputShape, inputData, starts, sizes, expectedShape, expected);
});

it('slice 3D', function() {
it('slice 3D default options', function() {
const inputShape = [3, 4, 5];
const inputData = [
1.3165863e+00, 4.1239005e-02, 4.6697399e-01, -6.6145003e-02,
Expand Down Expand Up @@ -54,4 +54,36 @@ describe('test slice', function() {
];
testSlice(inputShape, inputData, starts, sizes, expectedShape, expected);
});

it('slice 2D with strides=[1, 5]', function() {
const inputShape = [3, 20];
const inputData = [
1.3165863e+00, 4.1239005e-02, 4.6697399e-01, -6.6145003e-02,
-3.7128052e-01, -1.0660021e+00, 7.5784922e-01, 3.5759725e-02,
1.9211160e+00, -8.1603736e-01, 1.1800343e-01, -1.8293047e+00,
-2.1316205e-01, -3.6369815e-01, 6.4205879e-01, 7.1544610e-02,
6.8498695e-01, 1.0001093e+00, -5.6261641e-01, -7.3343945e-01,

1.6827687e+00, 1.2653192e+00, 5.8872145e-01, 3.1535852e-01,
3.5038650e-01, 3.5865438e-01, -3.6469769e-01, -8.7751287e-01,
2.7995768e-01, -1.6042528e+00, 8.6336482e-01, -1.7991974e+00,
-6.8652731e-01, 1.3729302e-03, -7.7775210e-01, 1.0199220e-01,
4.2299256e-01, 1.1432177e-01, -5.0116669e-02, 1.5525131e+00,

-8.7060851e-01, 4.5739245e-01, 1.3543987e-01, -1.5927458e-02,
9.1792661e-01, -4.5001405e-01, 1.9954188e-01, -5.1338053e-01,
-4.1026011e-01, -1.2718531e+00, 4.2538303e-01, -1.5449624e-01,
-3.4380481e-01, 7.8374326e-01, 1.7837452e+00, 9.6105379e-01,
-4.8783422e-01, -9.4987392e-01, -8.8750905e-01, -9.8019439e-01,
];
const starts = [1, 4];
const sizes = [2, 10];
const strides = [1, 3];
const expectedShape = [2, 4];
const expected = [
3.5038650e-01, -8.7751287e-01, 8.6336482e-01, 1.3729302e-03,
9.1792661e-01, -5.1338053e-01, 4.2538303e-01, 7.8374326e-01,
];
testSlice(inputShape, inputData, starts, sizes, expectedShape, expected, {strides});
});
});

0 comments on commit 97b9744

Please sign in to comment.