Skip to content

Commit

Permalink
Merge pull request #64 from BruceDai/add_triangular
Browse files Browse the repository at this point in the history
Implement triangular
  • Loading branch information
huningxin authored Jan 2, 2024
2 parents 64c784e + 584cc0a commit 32c2d3c
Show file tree
Hide file tree
Showing 3 changed files with 507 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,13 @@ export function validateGatherParams(input, indices, {axis = 0} = {}) {
}
}
}

export function validateTriangularParams(input, {diagonal = 0} = {}) {
const inputRank = input.rank;
if (inputRank < 2) {
throw new Error('The input should be at least a 2-D tensor.');
}
if (!Number.isInteger(diagonal)) {
throw new Error(`The diagonal should be an integer.`);
}
}
39 changes: 39 additions & 0 deletions src/triangular.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
'use strict';

import {Tensor, sizeOfShape} from './lib/tensor.js';
import {validateTriangularParams} from './lib/validate-input.js';

/**
* Get retained boolean flag.
* @param {Array} location
* @param {Boolean} upper
* @param {Number} diagonal
* @return {Boolean}
*/
function isRetainedValue(location, upper, diagonal) {
const i = location[location.length - 2];
const j = location[location.length - 1];
return upper ? j >= i + diagonal : j <= i + diagonal;
}

/**
* Given a 2-D tensor (matrix), return a 2-D tensor containing either the upper or lower triangular
* part of the input tensor.
* @param {Tensor} input
* @param {MLTriangularOptions} [options]
* @return {Tensor}
*/
export function triangular(input, {upper = true, diagonal = 0} = {}) {
validateTriangularParams(...arguments);
const shapeOutput = input.shape.slice();
const output = new Tensor(shapeOutput);

for (let outputIndex = 0; outputIndex < sizeOfShape(shapeOutput); ++outputIndex) {
const outputLoc = output.locationFromIndex(outputIndex);
const retainedFlag = isRetainedValue(outputLoc, upper, diagonal);
const inputValue = retainedFlag ? input.getValueByLocation(outputLoc) : 0;
output.setValueByLocation(outputLoc, inputValue);
}

return output;
}
Loading

0 comments on commit 32c2d3c

Please sign in to comment.