Skip to content

Commit

Permalink
[webnn] More refined code to reuse for other operation tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai committed Nov 30, 2022
1 parent fd68dea commit 7767126
Show file tree
Hide file tree
Showing 3 changed files with 1,819 additions and 730 deletions.
43 changes: 8 additions & 35 deletions webnn/concat.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,15 @@

// https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-concat

const concatTests = () => {
const resources = loadTestData('/webnn/resources/test_data/concat.json');
const tests = resources.tests;
const inputsData = resources.inputsData;
const expectedData = resources.expectedData;
const targetTests = [];
for (const test of tests) {
const inputShapeValues = [];
const inputShapes = test.inputs.shape;
const inputDataSource = test.inputs.data;
const expectedDataSource = test.expected.data;
let position = 0;
for (const shape of inputShapes) {
const size = sizeOfShape(shape);
inputShapeValues.push({shape, data: inputsData[inputDataSource].slice(position, position + size)});
position += size;
}
const expected = {shape: test.expected.shape, data: {outputOperand: expectedData[expectedDataSource]}};
targetTests.push({name: test.name, operandType: test.type, inputShapeValues, axis: test.axis, expected});
}
return targetTests;
};

const buildGraph = (builder, resources) => {
const inputShapeValues = resources.inputShapeValues;
const operandType = resources.operandType;
const TestTypedArray = TypedArrayDict[operandType];
const buildConcat = (builder, resources) => {
// MLOperand concat(sequence<MLOperand> inputs, long axis);
const namedOutputOperand = {};
const inputOperands = [];
const inputs = {};
for (let i = 0; i < inputShapeValues.length; i++) {
inputOperands.push(builder.input('input' + i, {type: operandType, dimensions: inputShapeValues[i].shape}));
inputs['input' + i] = new TestTypedArray(inputShapeValues[i].data);
for (let input of resources.inputs) {
inputOperands.push(builder.input(input.name, {type: input.type, dimensions: input.shape}));
}
const outputOperand = builder.concat(inputOperands, resources.axis);
const outputs = {'outputOperand': new TestTypedArray(sizeOfShape(resources.expected.shape))};
return [{outputOperand}, inputs, outputs];
namedOutputOperand[resources.expected.name] = builder.concat(inputOperands, resources.axis);
return namedOutputOperand;
};

testWebNNOperation('concat', concatTests(), buildGraph);
testWebNNOperation('concat', '/webnn/resources/test_data/concat.json', buildConcat);
Loading

0 comments on commit 7767126

Please sign in to comment.