Skip to content

Commit

Permalink
implement reverse
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai committed Dec 17, 2024
1 parent 9499088 commit b09c509
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,9 @@ export function validatePool2dParams(input, _, {roundingType = 'floor'}) {
}

export function validateReduceParams(input, {axes}) {
if (axes === undefined) {
return;
}
if (axes.length > input.rank) {
throw new Error(`The length ${axes.length} of axes is bigger` +
`than input rank ${input.rank}.`);
Expand Down
34 changes: 34 additions & 0 deletions src/reverse.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
'use strict';

import {sizeOfShape, Tensor} from './lib/tensor.js';
import {validateReduceParams} from './lib/validate-input.js';

/**
* Reverse input along the given axes.
* @param {Tensor} input
* @param {MLReverseOptions} [options]
* @return {Tensor}
*/
export function reverse(input, {axes}) {
// reuse validateReduceParams to check parameters of reverse
validateReduceParams(input, {axes});

const inputAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i);
const inputShape = input.shape;
const outputShape = inputShape.slice();
const outputSize = sizeOfShape(outputShape);
const output = new Tensor(outputShape);

for (let outputIndex = 0; outputIndex < outputSize; ++outputIndex) {
const outputLocation = output.locationFromIndex(outputIndex);
const selectedInputLocation = outputLocation.slice();
for (const axis of inputAxes) {
const index = selectedInputLocation[axis];
selectedInputLocation[axis] = inputShape[axis] - index - 1;
}
const selectedInputValue = input.getValueByLocation(selectedInputLocation);
output.setValueByLocation(outputLocation, selectedInputValue);
}

return output;
}
65 changes: 65 additions & 0 deletions test/reverse_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
'use strict';

import {reverse} from '../src/reverse.js';
import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test reverse', function() {
function testReverse(inputShape, inputValue, expected, options = {}) {
const inputTensor = new Tensor(inputShape, inputValue);
const outputTensor = reverse(inputTensor, options);
utils.checkShape(outputTensor, inputShape);
utils.checkValue(outputTensor, expected);
}

it('reverse 0D scalar default options', function() {
testReverse([], [2], [2]);
});

it('reverse 4D default options', function() {
testReverse(
[2, 2, 2, 2],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
[16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]);
});

it('reverse 2D with axes=[0]', function() {
testReverse(
[2, 2],
[1, 2, 3, 4],
[3, 4, 1, 2],
{
axes: [0],
});
});

it('reverse 4D with axes=[3]', function() {
testReverse(
[2, 2, 2, 2],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
[2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15],
{
axes: [3],
});
});

it('reverse 4D with axes=[1, 2]', function() {
testReverse(
[2, 2, 2, 2],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
[7, 8, 5, 6, 3, 4, 1, 2, 15, 16, 13, 14, 11, 12, 9, 10],
{
axes: [1, 2],
});
});

it('reverse 4D with axes=[3, 1, 0, 2]', function() {
testReverse(
[2, 2, 2, 2],
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
[16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1],
{
axes: [3, 1, 0, 2],
});
});
});

0 comments on commit b09c509

Please sign in to comment.