From fe295962849d4a81c24cd58aa21bf0ce3def995f Mon Sep 17 00:00:00 2001 From: BruceDai Date: Wed, 18 Dec 2024 10:10:14 +0800 Subject: [PATCH] support negative indices --- src/gather_nd.js | 3 ++- test/gather_nd_test.js | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/gather_nd.js b/src/gather_nd.js index aafbd92..88c1df3 100644 --- a/src/gather_nd.js +++ b/src/gather_nd.js @@ -35,7 +35,8 @@ export function gatherND(input, indices) { const indicesLocation = indices.locationFromIndex(indicesIndex); const indicesArray = []; for (let i = 0; i < lastIndicesSize; i++) { - indicesArray.push(indices.getValueByIndex(indicesIndex + i)); + const indicesValue = indices.getValueByIndex(indicesIndex + i); + indicesArray.push(indicesValue >= 0 ? indicesValue : inputShape[i] + indicesValue); } for (let tmpIndex = 0; tmpIndex < tmpTotal; ++tmpIndex) { const tmpLocation = tmp.locationFromIndex(tmpIndex); diff --git a/test/gather_nd_test.js b/test/gather_nd_test.js index ec16378..7e552b1 100644 --- a/test/gather_nd_test.js +++ b/test/gather_nd_test.js @@ -85,6 +85,28 @@ describe('test gatherND', function() { testGatherND(input, indices, expected); }); + it('gatherND 2D inputs when 3D negative indices has leading dimensions', function() { + const input = { + shape: [2, 2], + data: [ + 1, 2, + 3, 4, + ], + }; + const indices = { + shape: [2, 1, 1], + data: [-1, -2], + }; + const expected = { + shape: [2, 1, 2], + data: [ + 3, 4, + 1, 2, + ], + }; + testGatherND(input, indices, expected); + }); + it('gatherND 3D inputs by 2D indices', function() { // Refer to Example 3 on https://onnx.ai/onnx/operators/onnx__GatherND.html const input = { @@ -113,6 +135,7 @@ describe('test gatherND', function() { testGatherND(input, indices, expected); }); + it('gatherND 3D inputs by 3D indices', function() { // Refer to Example 4 on https://onnx.ai/onnx/operators/onnx__GatherND.html const input = {