diff --git a/nnotepad/js/nnotepad.js b/nnotepad/js/nnotepad.js index ba236f0f..a3674353 100644 --- a/nnotepad/js/nnotepad.js +++ b/nnotepad/js/nnotepad.js @@ -31,6 +31,10 @@ export class ComputeError extends Error { // General WebNN Utilities // ============================================================ +const kArgTypeOperandList = 1; +const kArgTypeNonOperand = 2; +const kArgTypeOperand = 3; + class WebNNUtil { static bufferForOperand(operand) { const size = [...operand.shape()].reduce((a, b) => a * b, 1); @@ -60,21 +64,22 @@ class WebNNUtil { throw new Error(`Unsupported dataType ${type}`); } - static isNonOperandArg(name, index) { + static argumentType(name, index) { return ({ - concat: [0, 1], - expand: [1], - gru: [3, 4], - gruCell: [4], - lstm: [3, 4], - lstmCell: [5], - pad: [1, 2], - reshape: [1], - slice: [1, 2], - softmax: [1], // TODO: Distinguish overloads - split: [1], + concat: {0: kArgTypeOperandList, 1: kArgTypeNonOperand}, + expand: {1: kArgTypeNonOperand}, + gru: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand}, + gruCell: {4: kArgTypeNonOperand}, + lstm: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand}, + lstmCell: {5: kArgTypeNonOperand}, + pad: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand}, + reshape: {1: kArgTypeNonOperand}, + slice: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand}, + softmax: {1: kArgTypeNonOperand}, + split: {1: kArgTypeNonOperand}, })[name] - ?.includes(index); + ?.[index] || + kArgTypeOperand; } } @@ -379,7 +384,7 @@ export class NNotepad { } throw new Error(`unexpected line type: ${line.type}`); } - function serializeExpr(expr, nonOperand = false) { + function serializeExpr(expr, argumentType = kArgTypeOperand) { if (expr.op) { if (expr.lhs) { return `_.${kBinaryOperators[expr.op]}(${serializeExpr(expr.lhs)}, ${ @@ -394,11 +399,21 @@ export class NNotepad { case 'boolean': return String(expr.value); case 'number': - return nonOperand ? Util.stringify(expr.value) : - serializeScalar(expr.value, expr.dataType); + switch (argumentType) { + case kArgTypeNonOperand: + return Util.stringify(expr.value); + default: + return serializeScalar(expr.value, expr.dataType); + } case 'array': - return nonOperand ? serializeArray(expr.value) : - serializeTensor(expr.value, expr.dataType); + switch (argumentType) { + case kArgTypeNonOperand: + return serializeArray(expr.value, kArgTypeNonOperand); + case kArgTypeOperandList: + return serializeArray(expr.value, kArgTypeOperand); + default: + return serializeTensor(expr.value, expr.dataType); + } case 'dict': return serializeDict(expr.dict); case 'identifier': @@ -414,7 +429,7 @@ export class NNotepad { .map((k) => { const v = dict[k]; k = Util.stringify(k); - return `${k}: ${serializeExpr(v, true)}`; + return `${k}: ${serializeExpr(v, kArgTypeNonOperand)}`; }) .join(', ') + '}'; @@ -465,8 +480,10 @@ export class NNotepad { elements.map((n) => Util.stringifyNumber(n, dataType)).join(',')}]))`; } - function serializeArray(array) { - return '[' + array.map((expr) => serializeExpr(expr)).join(', ') + ']'; + function serializeArray(array, argumentType) { + return '[' + + array.map((expr) => serializeExpr(expr, argumentType)).join(', ') + + ']'; } function serializeCall(name, args) { @@ -506,8 +523,8 @@ export class NNotepad { return `_.${name}(${ args.map( - (arg, index) => serializeExpr( - arg, WebNNUtil.isNonOperandArg(name, index))) + (arg, index) => + serializeExpr(arg, WebNNUtil.argumentType(name, index))) .join(', ')})`; } } diff --git a/nnotepad/js/tests.js b/nnotepad/js/tests.js index 5572b5c7..c789dd20 100644 --- a/nnotepad/js/tests.js +++ b/nnotepad/js/tests.js @@ -157,7 +157,7 @@ document.addEventListener('DOMContentLoaded', async (e) => { {dataType: 'float32', shape: [2], buffer: [3, 4]}, ]); - Harness.section('Multiple input tensors'); + Harness.section('Non-operand arguments: array of operands'); await test( `A = [1,2] B = [3,4] concat([A,B], 0)`, {dataType: 'float32', shape: [4], buffer: [1, 2, 3, 4]}); @@ -165,6 +165,19 @@ document.addEventListener('DOMContentLoaded', async (e) => { `concat([identity([1,2]),identity([3,4])], 0)`, {dataType: 'float32', shape: [4], buffer: [1, 2, 3, 4]}); + Harness.section('Non-operand arguments: array of numbers'); + await test( + `T = [[1,2,3],[4,5,6]] reshape(T, [1, 3, 2, 1])`, + {dataType: 'float32', shape: [1, 3, 2, 1], buffer: [1, 2, 3, 4, 5, 6]}); + await test( + `expand([1], [2, 2])`, + {dataType: 'float32', shape: [2, 2], buffer: [1, 1, 1, 1]}); + + Harness.section('Non-operand arguments: simple numbers'); + await test( + `softmax([1], 0)`, + {dataType: 'float32', shape: [1], buffer: [1]}); + Harness.section('Regression tests'); await test( `concat([[1,2],[3,4]], 0)`,