Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement gatherND #119

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/gather_nd.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
'use strict';

import {Tensor, sizeOfShape} from './lib/tensor.js';
import {validateGatherNDParams} from './lib/validate-input.js';

/**
* Gathers values using multidimensional indices.
* @param {Tensor} input
* @param {Tensor} indices
* @return {Tensor}
*/
export function gatherND(input, indices) {
validateGatherNDParams(input, indices);

const inputRank = input.rank;
const inputShape = input.shape;
const indicesRank = indices.rank;
const indicesShape = indices.shape;
const lastIndicesSize = indicesShape[indicesRank - 1];
const tmpShape = inputShape.slice(lastIndicesSize);

// Refer to https://docs.openvino.ai/2024/documentation/openvino-ir-format/operation-sets/ +
// operation-specs/movement/gather-nd-8.html
let outputShape = indicesShape.slice(0, indicesRank - 1);
if (lastIndicesSize !== inputRank) {
outputShape = outputShape.concat(tmpShape);
}

const output = new Tensor(outputShape);
const indicesTotal = sizeOfShape(indicesShape);
const tmp = new Tensor(tmpShape);
const tmpTotal = sizeOfShape(tmpShape);

for (let indicesIndex = 0; indicesIndex < indicesTotal; indicesIndex += lastIndicesSize) {
const indicesLocation = indices.locationFromIndex(indicesIndex);
const indicesArray = [];
for (let i = 0; i < lastIndicesSize; i++) {
const indicesValue = indices.getValueByIndex(indicesIndex + i);
indicesArray.push(indicesValue >= 0 ? indicesValue : inputShape[i] + indicesValue);
}
for (let tmpIndex = 0; tmpIndex < tmpTotal; ++tmpIndex) {
const tmpLocation = tmp.locationFromIndex(tmpIndex);
const outputLocation = indicesLocation.slice(0, indicesRank - 1).concat(tmpLocation);
const inputValue = input.getValueByLocation(indicesArray.concat(tmpLocation));
// output[i_0, ..., i_{K-2},:,...,:] = data[indices[i_0, ..., i_{K-2}],:,...,:]
output.setValueByLocation(outputLocation, inputValue);
}
}

return output;
}
36 changes: 36 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,42 @@ export function validateGatherParams(input, indices, {axis = 0} = {}) {
}
}

export function validateGatherNDParams(input, indices) {
const inputRank = input.rank;
const inputShape = input.shape;
const indicesRank = indices.rank;
const indicesShape = indices.shape;

// Refer to https://docs.openvino.ai/2024/documentation/openvino-ir-format/operation-sets/ +
// operation-specs/movement/gather-nd-8.html
if (inputRank < 1) {
throw new Error(`The input rank shouldn't be less than 1.`);
}

if (indicesRank < 1) {
throw new Error(`The indices rank shouldn't be less than 1.`);
}

const lastIndicesSize = indicesShape[indicesRank - 1];
if (lastIndicesSize < 1 || lastIndicesSize > inputRank) {
throw new Error(`The indices.shape[-1] should be in the range [1, ${inputRank}].`);
}

const indicesTotal = sizeOfShape(indicesShape);
for (let indicesIndex = 0; indicesIndex < indicesTotal; indicesIndex += lastIndicesSize) {
for (let i = 0; i < lastIndicesSize; i++) {
const indicesValue = indices.getValueByIndex(indicesIndex + i);
const maxSize = inputShape[i];
if (!Number.isInteger(indicesValue) ||
indicesValue < -maxSize ||
indicesValue > maxSize - 1) {
throw new Error(`Invalid indices value - it should be an integer in the interval ` +
`[${-maxSize}, ${maxSize - 1}]`);
}
}
}
}

export function validateScatterElementsParams(input, indices, updates, {axis = 0} = {}) {
const inputRank = input.rank;
const indicesRank = indices.rank;
Expand Down
207 changes: 207 additions & 0 deletions test/gather_nd_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
'use strict';

import {gatherND} from '../src/gather_nd.js';
import {Tensor} from '../src/lib/tensor.js';
import * as utils from './utils.js';

describe('test gatherND', function() {
function testGatherND(input, indices, expected) {
const inputTensor = new Tensor(input.shape, input.data);
const indicesTensor = new Tensor(indices.shape, indices.data);
const outputTensor = gatherND(inputTensor, indicesTensor);
utils.checkShape(outputTensor, expected.shape);
utils.checkValue(outputTensor, expected.data);
}

it('gatherND with elements from 2D inputs by 2D indices', function() {
// Refer to Example 1 on https://docs.openvino.ai/2024/documentation/openvino-ir-format/ +
// operation-sets/operation-specs/movement/gather-nd-8.html
const input = {
shape: [2, 2],
data: [
1, 2,
3, 4,
],
};
const indices = {
shape: [2, 2],
data: [
0, 0,
1, 0,
],
};
const expected = {
shape: [2],
data: [1, 3],
};
testGatherND(input, indices, expected);
});

it('gatherND with slices from 2D inputs by 2D indices', function() {
// Refer to Example 2 on https://docs.openvino.ai/2024/documentation/openvino-ir-format/ +
// operation-sets/operation-specs/movement/gather-nd-8.html
const input = {
shape: [2, 2],
data: [
1, 2,
3, 4,
],
};
const indices = {
shape: [2, 1],
data: [1, 0],
};
const expected = {
shape: [2, 2],
data: [
3, 4,
1, 2,
],
};
testGatherND(input, indices, expected);
});

it('gatherND 2D inputs when 3D indices has leading dimensions', function() {
// Refer to Example 3 on https://docs.openvino.ai/2024/documentation/openvino-ir-format/ +
// operation-sets/operation-specs/movement/gather-nd-8.html
const input = {
shape: [2, 2],
data: [
1, 2,
3, 4,
],
};
const indices = {
shape: [2, 1, 1],
data: [1, 0],
};
const expected = {
shape: [2, 1, 2],
data: [
3, 4,
1, 2,
],
};
testGatherND(input, indices, expected);
});

it('gatherND 2D inputs when 3D negative indices has leading dimensions', function() {
const input = {
shape: [2, 2],
data: [
1, 2,
3, 4,
],
};
const indices = {
shape: [2, 1, 1],
data: [-1, -2],
};
const expected = {
shape: [2, 1, 2],
data: [
3, 4,
1, 2,
],
};
testGatherND(input, indices, expected);
});

it('gatherND 3D inputs by 2D indices', function() {
// Refer to Example 3 on https://onnx.ai/onnx/operators/onnx__GatherND.html
const input = {
shape: [2, 2, 2],
data: [
0, 1,
2, 3,
4, 5,
6, 7,
],
};
const indices = {
shape: [2, 2],
data: [
0, 1,
1, 0,
],
};
const expected = {
shape: [2, 2],
data: [
2, 3,
4, 5,
],
};
testGatherND(input, indices, expected);
});


it('gatherND 3D inputs by 3D indices', function() {
// Refer to Example 4 on https://onnx.ai/onnx/operators/onnx__GatherND.html
const input = {
shape: [2, 2, 2],
data: [
0, 1,
2, 3,
4, 5,
6, 7,
],
};
const indices = {
shape: [2, 1, 2],
data: [
0, 1,
1, 0,
],
};
const expected = {
shape: [2, 1, 2],
data: [
2, 3,
4, 5,
],
};
testGatherND(input, indices, expected);
});

it('gatherND 1D inputs by 1D indices', function() {
const input = {
shape: [4],
data: [1, 2, 3, 4],
};
const indices = {
shape: [1],
data: [1],
};
const expected = {
shape: [],
data: [2],
};
testGatherND(input, indices, expected);
});

it('gatherND 4D inputs by 1D indices', function() {
const input = {
shape: [2, 2, 2, 2],
data: [
1, 2,
3, 4,
5, 6,
7, 8,
9, 10,
11, 12,
13, 14,
15, 16,
],
};
const indices = {
shape: [3],
data: [1, 0, 0],
};
const expected = {
shape: [2],
data: [9, 10],
};
testGatherND(input, indices, expected);
});
});
Loading