diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index 756256c..531a314 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -471,7 +471,10 @@ export function validatePool2dParams(input, _, {roundingType = 'floor'}) { } } -export function validateReduceParams(input, {axes}) { +export function validateAxes(input, axes) { + if (axes === undefined) { + return; + } if (axes.length > input.rank) { throw new Error(`The length ${axes.length} of axes is bigger` + `than input rank ${input.rank}.`); diff --git a/src/reduce.js b/src/reduce.js index fcb39c9..8c72aae 100644 --- a/src/reduce.js +++ b/src/reduce.js @@ -4,10 +4,10 @@ import {pow} from './binary.js'; import {squeeze} from './reshape.js'; import {abs, exp, log} from './unary.js'; import {sizeOfShape, Scalar, Tensor} from './lib/tensor.js'; -import {validateReduceParams} from './lib/validate-input.js'; +import {validateAxes} from './lib/validate-input.js'; export function selectValuesToReduce(input, axes, inputLocation) { - validateReduceParams(input, {axes}); + validateAxes(input, axes); const outputShape = input.shape.slice(); for (let i = 0; i < axes.length; ++i) { diff --git a/src/reverse.js b/src/reverse.js new file mode 100644 index 0000000..788674e --- /dev/null +++ b/src/reverse.js @@ -0,0 +1,33 @@ +'use strict'; + +import {sizeOfShape, Tensor} from './lib/tensor.js'; +import {validateAxes} from './lib/validate-input.js'; + +/** + * Reverse input along the given axes. + * @param {Tensor} input + * @param {MLReverseOptions} [options] + * @return {Tensor} + */ +export function reverse(input, {axes}) { + validateAxes(input, axes); + + const inputAxes = axes ?? new Array(input.rank).fill(0).map((_, i) => i); + const inputShape = input.shape; + const outputShape = inputShape.slice(); + const outputSize = sizeOfShape(outputShape); + const output = new Tensor(outputShape); + + for (let outputIndex = 0; outputIndex < outputSize; ++outputIndex) { + const outputLocation = output.locationFromIndex(outputIndex); + const selectedInputLocation = outputLocation.slice(); + for (const axis of inputAxes) { + const index = selectedInputLocation[axis]; + selectedInputLocation[axis] = inputShape[axis] - index - 1; + } + const selectedInputValue = input.getValueByLocation(selectedInputLocation); + output.setValueByLocation(outputLocation, selectedInputValue); + } + + return output; +} diff --git a/test/reverse_test.js b/test/reverse_test.js new file mode 100644 index 0000000..cb6af2e --- /dev/null +++ b/test/reverse_test.js @@ -0,0 +1,65 @@ +'use strict'; + +import {reverse} from '../src/reverse.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test reverse', function() { + function testReverse(inputShape, inputValue, expected, options = {}) { + const inputTensor = new Tensor(inputShape, inputValue); + const outputTensor = reverse(inputTensor, options); + utils.checkShape(outputTensor, inputShape); + utils.checkValue(outputTensor, expected); + } + + it('reverse 0D scalar default options', function() { + testReverse([], [2], [2]); + }); + + it('reverse 4D default options', function() { + testReverse( + [2, 2, 2, 2], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]); + }); + + it('reverse 2D with axes=[0]', function() { + testReverse( + [2, 2], + [1, 2, 3, 4], + [3, 4, 1, 2], + { + axes: [0], + }); + }); + + it('reverse 4D with axes=[3]', function() { + testReverse( + [2, 2, 2, 2], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + [2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15], + { + axes: [3], + }); + }); + + it('reverse 4D with axes=[1, 2]', function() { + testReverse( + [2, 2, 2, 2], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + [7, 8, 5, 6, 3, 4, 1, 2, 15, 16, 13, 14, 11, 12, 9, 10], + { + axes: [1, 2], + }); + }); + + it('reverse 4D with axes=[3, 1, 0, 2]', function() { + testReverse( + [2, 2, 2, 2], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1], + { + axes: [3, 1, 0, 2], + }); + }); +});