From 22aa524577f1468e324fac663a2d46172b9b75b9 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Tue, 17 Dec 2024 20:31:27 +0800 Subject: [PATCH 1/3] implement gatherND --- src/gather_nd.js | 50 +++++++++++ src/lib/validate-input.js | 36 ++++++++ test/gather_nd_test.js | 184 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 270 insertions(+) create mode 100644 src/gather_nd.js create mode 100644 test/gather_nd_test.js diff --git a/src/gather_nd.js b/src/gather_nd.js new file mode 100644 index 0000000..aafbd92 --- /dev/null +++ b/src/gather_nd.js @@ -0,0 +1,50 @@ +'use strict'; + +import {Tensor, sizeOfShape} from './lib/tensor.js'; +import {validateGatherNDParams} from './lib/validate-input.js'; + +/** + * Gathers values using multidimensional indices. + * @param {Tensor} input + * @param {Tensor} indices + * @return {Tensor} + */ +export function gatherND(input, indices) { + validateGatherNDParams(input, indices); + + const inputRank = input.rank; + const inputShape = input.shape; + const indicesRank = indices.rank; + const indicesShape = indices.shape; + const lastIndicesSize = indicesShape[indicesRank - 1]; + const tmpShape = inputShape.slice(lastIndicesSize); + + // Refer to https://docs.openvino.ai/2024/documentation/openvino-ir-format/operation-sets/ + + // operation-specs/movement/gather-nd-8.html + let outputShape = indicesShape.slice(0, indicesRank - 1); + if (lastIndicesSize !== inputRank) { + outputShape = outputShape.concat(tmpShape); + } + + const output = new Tensor(outputShape); + const indicesTotal = sizeOfShape(indicesShape); + const tmp = new Tensor(tmpShape); + const tmpTotal = sizeOfShape(tmpShape); + + for (let indicesIndex = 0; indicesIndex < indicesTotal; indicesIndex += lastIndicesSize) { + const indicesLocation = indices.locationFromIndex(indicesIndex); + const indicesArray = []; + for (let i = 0; i < lastIndicesSize; i++) { + indicesArray.push(indices.getValueByIndex(indicesIndex + i)); + } + for (let tmpIndex = 0; tmpIndex < tmpTotal; ++tmpIndex) { + const tmpLocation = tmp.locationFromIndex(tmpIndex); + const outputLocation = indicesLocation.slice(0, indicesRank - 1).concat(tmpLocation); + const inputValue = input.getValueByLocation(indicesArray.concat(tmpLocation)); + // output[i_0, ..., i_{K-2},:,...,:] = data[indices[i_0, ..., i_{K-2}],:,...,:] + output.setValueByLocation(outputLocation, inputValue); + } + } + + return output; +} diff --git a/src/lib/validate-input.js b/src/lib/validate-input.js index 78d8847..df555c1 100644 --- a/src/lib/validate-input.js +++ b/src/lib/validate-input.js @@ -623,6 +623,42 @@ export function validateGatherParams(input, indices, {axis = 0} = {}) { } } +export function validateGatherNDParams(input, indices) { + const inputRank = input.rank; + const inputShape = input.shape; + const indicesRank = indices.rank; + const indicesShape = indices.shape; + + // Refer to https://docs.openvino.ai/2024/documentation/openvino-ir-format/operation-sets/ + + // operation-specs/movement/gather-nd-8.html + if (inputRank < 1) { + throw new Error(`The input rank shouldn't be less than 1.`); + } + + if (indicesRank < 1) { + throw new Error(`The indices rank shouldn't be less than 1.`); + } + + const lastIndicesSize = indicesShape[indicesRank - 1]; + if (lastIndicesSize < 1 || lastIndicesSize > inputRank) { + throw new Error(`The indices.shape[-1] should be in the range [1, ${inputRank}].`); + } + + const indicesTotal = sizeOfShape(indicesShape); + for (let indicesIndex = 0; indicesIndex < indicesTotal; indicesIndex += lastIndicesSize) { + for (let i = 0; i < lastIndicesSize; i++) { + const indicesValue = indices.getValueByIndex(indicesIndex + i); + const maxSize = inputShape[i]; + if (!Number.isInteger(indicesValue) || + indicesValue < -maxSize || + indicesValue > maxSize - 1) { + throw new Error(`Invalid indices value - it should be an integer in the interval ` + + `[${-maxSize}, ${maxSize - 1}]`); + } + } + } +} + export function validateScatterElementsParams(input, indices, updates, {axis = 0} = {}) { const inputRank = input.rank; const indicesRank = indices.rank; diff --git a/test/gather_nd_test.js b/test/gather_nd_test.js new file mode 100644 index 0000000..ec16378 --- /dev/null +++ b/test/gather_nd_test.js @@ -0,0 +1,184 @@ +'use strict'; + +import {gatherND} from '../src/gather_nd.js'; +import {Tensor} from '../src/lib/tensor.js'; +import * as utils from './utils.js'; + +describe('test gatherND', function() { + function testGatherND(input, indices, expected) { + const inputTensor = new Tensor(input.shape, input.data); + const indicesTensor = new Tensor(indices.shape, indices.data); + const outputTensor = gatherND(inputTensor, indicesTensor); + utils.checkShape(outputTensor, expected.shape); + utils.checkValue(outputTensor, expected.data); + } + + it('gatherND with elements from 2D inputs by 2D indices', function() { + // Refer to Example 1 on https://docs.openvino.ai/2024/documentation/openvino-ir-format/ + + // operation-sets/operation-specs/movement/gather-nd-8.html + const input = { + shape: [2, 2], + data: [ + 1, 2, + 3, 4, + ], + }; + const indices = { + shape: [2, 2], + data: [ + 0, 0, + 1, 0, + ], + }; + const expected = { + shape: [2], + data: [1, 3], + }; + testGatherND(input, indices, expected); + }); + + it('gatherND with slices from 2D inputs by 2D indices', function() { + // Refer to Example 2 on https://docs.openvino.ai/2024/documentation/openvino-ir-format/ + + // operation-sets/operation-specs/movement/gather-nd-8.html + const input = { + shape: [2, 2], + data: [ + 1, 2, + 3, 4, + ], + }; + const indices = { + shape: [2, 1], + data: [1, 0], + }; + const expected = { + shape: [2, 2], + data: [ + 3, 4, + 1, 2, + ], + }; + testGatherND(input, indices, expected); + }); + + it('gatherND 2D inputs when 3D indices has leading dimensions', function() { + // Refer to Example 3 on https://docs.openvino.ai/2024/documentation/openvino-ir-format/ + + // operation-sets/operation-specs/movement/gather-nd-8.html + const input = { + shape: [2, 2], + data: [ + 1, 2, + 3, 4, + ], + }; + const indices = { + shape: [2, 1, 1], + data: [1, 0], + }; + const expected = { + shape: [2, 1, 2], + data: [ + 3, 4, + 1, 2, + ], + }; + testGatherND(input, indices, expected); + }); + + it('gatherND 3D inputs by 2D indices', function() { + // Refer to Example 3 on https://onnx.ai/onnx/operators/onnx__GatherND.html + const input = { + shape: [2, 2, 2], + data: [ + 0, 1, + 2, 3, + 4, 5, + 6, 7, + ], + }; + const indices = { + shape: [2, 2], + data: [ + 0, 1, + 1, 0, + ], + }; + const expected = { + shape: [2, 2], + data: [ + 2, 3, + 4, 5, + ], + }; + testGatherND(input, indices, expected); + }); + + it('gatherND 3D inputs by 3D indices', function() { + // Refer to Example 4 on https://onnx.ai/onnx/operators/onnx__GatherND.html + const input = { + shape: [2, 2, 2], + data: [ + 0, 1, + 2, 3, + 4, 5, + 6, 7, + ], + }; + const indices = { + shape: [2, 1, 2], + data: [ + 0, 1, + 1, 0, + ], + }; + const expected = { + shape: [2, 1, 2], + data: [ + 2, 3, + 4, 5, + ], + }; + testGatherND(input, indices, expected); + }); + + it('gatherND 1D inputs by 1D indices', function() { + const input = { + shape: [4], + data: [1, 2, 3, 4], + }; + const indices = { + shape: [1], + data: [1], + }; + const expected = { + shape: [], + data: [2], + }; + testGatherND(input, indices, expected); + }); + + it('gatherND 4D inputs by 1D indices', function() { + const input = { + shape: [2, 2, 2, 2], + data: [ + 1, 2, + 3, 4, + 5, 6, + 7, 8, + 9, 10, + 11, 12, + 13, 14, + 15, 16, + ], + }; + const indices = { + shape: [3], + data: [1, 0, 0], + }; + const expected = { + shape: [2], + data: [9, 10], + }; + testGatherND(input, indices, expected); + }); +}); From fe295962849d4a81c24cd58aa21bf0ce3def995f Mon Sep 17 00:00:00 2001 From: BruceDai Date: Wed, 18 Dec 2024 10:10:14 +0800 Subject: [PATCH 2/3] support negative indices --- src/gather_nd.js | 3 ++- test/gather_nd_test.js | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/gather_nd.js b/src/gather_nd.js index aafbd92..88c1df3 100644 --- a/src/gather_nd.js +++ b/src/gather_nd.js @@ -35,7 +35,8 @@ export function gatherND(input, indices) { const indicesLocation = indices.locationFromIndex(indicesIndex); const indicesArray = []; for (let i = 0; i < lastIndicesSize; i++) { - indicesArray.push(indices.getValueByIndex(indicesIndex + i)); + const indicesValue = indices.getValueByIndex(indicesIndex + i); + indicesArray.push(indicesValue >= 0 ? indicesValue : inputShape[i] + indicesValue); } for (let tmpIndex = 0; tmpIndex < tmpTotal; ++tmpIndex) { const tmpLocation = tmp.locationFromIndex(tmpIndex); diff --git a/test/gather_nd_test.js b/test/gather_nd_test.js index ec16378..7e552b1 100644 --- a/test/gather_nd_test.js +++ b/test/gather_nd_test.js @@ -85,6 +85,28 @@ describe('test gatherND', function() { testGatherND(input, indices, expected); }); + it('gatherND 2D inputs when 3D negative indices has leading dimensions', function() { + const input = { + shape: [2, 2], + data: [ + 1, 2, + 3, 4, + ], + }; + const indices = { + shape: [2, 1, 1], + data: [-1, -2], + }; + const expected = { + shape: [2, 1, 2], + data: [ + 3, 4, + 1, 2, + ], + }; + testGatherND(input, indices, expected); + }); + it('gatherND 3D inputs by 2D indices', function() { // Refer to Example 3 on https://onnx.ai/onnx/operators/onnx__GatherND.html const input = { @@ -113,6 +135,7 @@ describe('test gatherND', function() { testGatherND(input, indices, expected); }); + it('gatherND 3D inputs by 3D indices', function() { // Refer to Example 4 on https://onnx.ai/onnx/operators/onnx__GatherND.html const input = { From 4a5a0f31addf966cb881991f80c0a150fee54230 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Wed, 18 Dec 2024 10:21:18 +0800 Subject: [PATCH 3/3] Trigger CI for MacOS