From b09c509744619fc19ed3226ac65948b5e4c76504 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Thu, 12 Dec 2024 16:22:04 +0800 Subject: [PATCH 1/3] implement reverse --- src/lib/validate-input.js | 3 ++ src/reverse.js | 34 ++++++++++++++++++++ test/reverse_test.js | 65 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+) create mode 100644 src/reverse.js create mode 100644 test/reverse_test.js diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index 756256c..111543a 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -472,6 +472,9 @@ export function validatePool2dParams(input, _, {roundingType = 'floor'}) { } export function validateReduceParams(input, {axes}) { + if (axes === undefined) { + return; + } if (axes.length > input.rank) { throw new Error(`The length ${axes.length} of axes is bigger` + `than input rank ${input.rank}.`); diff --git a/src/reverse.js b/src/reverse.js new file mode 100644 index 0000000..234229e --- /dev/null +++ b/src/reverse.js @@ -0,0 +1,34 @@ +'use strict'; + +import {sizeOfShape, Tensor} from './lib/tensor.js'; +import {validateReduceParams} from './lib/validate-input.js'; + +/** + * Reverse input along the given axes. + * @param {Tensor} input + * @param {MLReverseOptions} [options] + * @return {Tensor} + */ +export function reverse(input, {axes}) { + // reuse validateReduceParams to check parameters of reverse + validateReduceParams(input, {axes}); + + const inputAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i); + const inputShape = input.shape; + const outputShape = inputShape.slice(); + const outputSize = sizeOfShape(outputShape); + const output = new Tensor(outputShape); + + for (let outputIndex = 0; outputIndex < outputSize; ++outputIndex) { + const outputLocation = output.locationFromIndex(outputIndex); + const selectedInputLocation = outputLocation.slice(); + for (const axis of inputAxes) { + const index = selectedInputLocation[axis]; + selectedInputLocation[axis] = inputShape[axis] - index - 1; + } + const selectedInputValue = input.getValueByLocation(selectedInputLocation); + output.setValueByLocation(outputLocation, selectedInputValue); + } + + return output; +} diff --git a/test/reverse_test.js b/test/reverse_test.js new file mode 100644 index 0000000..cb6af2e --- /dev/null +++ b/test/reverse_test.js @@ -0,0 +1,65 @@ +'use strict'; + +import {reverse} from '../src/reverse.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test reverse', function() { + function testReverse(inputShape, inputValue, expected, options = {}) { + const inputTensor = new Tensor(inputShape, inputValue); + const outputTensor = reverse(inputTensor, options); + utils.checkShape(outputTensor, inputShape); + utils.checkValue(outputTensor, expected); + } + + it('reverse 0D scalar default options', function() { + testReverse([], [2], [2]); + }); + + it('reverse 4D default options', function() { + testReverse( + [2, 2, 2, 2], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]); + }); + + it('reverse 2D with axes=[0]', function() { + testReverse( + [2, 2], + [1, 2, 3, 4], + [3, 4, 1, 2], + { + axes: [0], + }); + }); + + it('reverse 4D with axes=[3]', function() { + testReverse( + [2, 2, 2, 2], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + [2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15], + { + axes: [3], + }); + }); + + it('reverse 4D with axes=[1, 2]', function() { + testReverse( + [2, 2, 2, 2], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + [7, 8, 5, 6, 3, 4, 1, 2, 15, 16, 13, 14, 11, 12, 9, 10], + { + axes: [1, 2], + }); + }); + + it('reverse 4D with axes=[3, 1, 0, 2]', function() { + testReverse( + [2, 2, 2, 2], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1], + { + axes: [3, 1, 0, 2], + }); + }); +}); From 602a411c04a89d47e129a67237b155174d1a8d45 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Tue, 17 Dec 2024 09:25:33 +0800 Subject: [PATCH 2/3] rename common function as validateAxes --- src/lib/validate-input.js | 2 +- src/reduce.js | 4 ++-- src/reverse.js | 5 ++--- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index 111543a..4d4a464 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -471,7 +471,7 @@ export function validatePool2dParams(input, _, {roundingType = 'floor'}) { } } -export function validateReduceParams(input, {axes}) { +export function validateAxes(input, {axes}) { if (axes === undefined) { return; } diff --git a/src/reduce.js b/src/reduce.js index fcb39c9..b2563c2 100644 --- a/src/reduce.js +++ b/src/reduce.js @@ -4,10 +4,10 @@ import {pow} from './binary.js'; import {squeeze} from './reshape.js'; import {abs, exp, log} from './unary.js'; import {sizeOfShape, Scalar, Tensor} from './lib/tensor.js'; -import {validateReduceParams} from './lib/validate-input.js'; +import {validateAxes} from './lib/validate-input.js'; export function selectValuesToReduce(input, axes, inputLocation) { - validateReduceParams(input, {axes}); + validateAxes(input, {axes}); const outputShape = input.shape.slice(); for (let i = 0; i < axes.length; ++i) { diff --git a/src/reverse.js b/src/reverse.js index 234229e..5d006b8 100644 --- a/src/reverse.js +++ b/src/reverse.js @@ -1,7 +1,7 @@ 'use strict'; import {sizeOfShape, Tensor} from './lib/tensor.js'; -import {validateReduceParams} from './lib/validate-input.js'; +import {validateAxes} from './lib/validate-input.js'; /** * Reverse input along the given axes. @@ -10,8 +10,7 @@ import {validateReduceParams} from './lib/validate-input.js'; * @return {Tensor} */ export function reverse(input, {axes}) { - // reuse validateReduceParams to check parameters of reverse - validateReduceParams(input, {axes}); + validateAxes(input, {axes}); const inputAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i); const inputShape = input.shape; From 90203e6bc4b2d8876cb6d5e43aa484274cff848b Mon Sep 17 00:00:00 2001 From: BruceDai Date: Tue, 17 Dec 2024 10:47:53 +0800 Subject: [PATCH 3/3] update validateAxes function --- src/lib/validate-input.js | 2 +- src/reduce.js | 2 +- src/reverse.js | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index 4d4a464..531a314 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -471,7 +471,7 @@ export function validatePool2dParams(input, _, {roundingType = 'floor'}) { } } -export function validateAxes(input, {axes}) { +export function validateAxes(input, axes) { if (axes === undefined) { return; } diff --git a/src/reduce.js b/src/reduce.js index b2563c2..8c72aae 100644 --- a/src/reduce.js +++ b/src/reduce.js @@ -7,7 +7,7 @@ import {sizeOfShape, Scalar, Tensor} from './lib/tensor.js'; import {validateAxes} from './lib/validate-input.js'; export function selectValuesToReduce(input, axes, inputLocation) { - validateAxes(input, {axes}); + validateAxes(input, axes); const outputShape = input.shape.slice(); for (let i = 0; i < axes.length; ++i) { diff --git a/src/reverse.js b/src/reverse.js index 5d006b8..788674e 100644 --- a/src/reverse.js +++ b/src/reverse.js @@ -10,7 +10,7 @@ import {validateAxes} from './lib/validate-input.js'; * @return {Tensor} */ export function reverse(input, {axes}) { - validateAxes(input, {axes}); + validateAxes(input, axes); const inputAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i); const inputShape = input.shape;