From 458deccd6a2f79e03c769a4edcd91676a3c49ca6 Mon Sep 17 00:00:00 2001 From: Joshua Bell Date: Tue, 23 Jul 2024 19:00:58 -0700 Subject: [PATCH] [NNotepad] Fix serialization of non-tensor arrays (#264) Most MLGraphBuilder method arguments are operands, but sometimes they're arrays or numbers. There was a mechanism in place to determine (via a table) whether a given method's argument was an operator or not, but that wasn't enough to distinguish an array of numbers (like reshape() needs) from an array of operands (which concat() needs). Rework the lookup from returning a boolean to returning an argument type, and adjust serialization appropriately. Fixes #263 --- nnotepad/js/nnotepad.js | 63 ++++++++++++++++++++++++++--------------- nnotepad/js/tests.js | 15 +++++++++- 2 files changed, 54 insertions(+), 24 deletions(-) diff --git a/nnotepad/js/nnotepad.js b/nnotepad/js/nnotepad.js index ba236f0f..a3674353 100644 --- a/nnotepad/js/nnotepad.js +++ b/nnotepad/js/nnotepad.js @@ -31,6 +31,10 @@ export class ComputeError extends Error { // General WebNN Utilities // ============================================================ +const kArgTypeOperandList = 1; +const kArgTypeNonOperand = 2; +const kArgTypeOperand = 3; + class WebNNUtil { static bufferForOperand(operand) { const size = [...operand.shape()].reduce((a, b) => a * b, 1); @@ -60,21 +64,22 @@ class WebNNUtil { throw new Error(`Unsupported dataType ${type}`); } - static isNonOperandArg(name, index) { + static argumentType(name, index) { return ({ - concat: [0, 1], - expand: [1], - gru: [3, 4], - gruCell: [4], - lstm: [3, 4], - lstmCell: [5], - pad: [1, 2], - reshape: [1], - slice: [1, 2], - softmax: [1], // TODO: Distinguish overloads - split: [1], + concat: {0: kArgTypeOperandList, 1: kArgTypeNonOperand}, + expand: {1: kArgTypeNonOperand}, + gru: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand}, + gruCell: {4: kArgTypeNonOperand}, + lstm: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand}, + lstmCell: {5: kArgTypeNonOperand}, + pad: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand}, + reshape: {1: kArgTypeNonOperand}, + slice: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand}, + softmax: {1: kArgTypeNonOperand}, + split: {1: kArgTypeNonOperand}, })[name] - ?.includes(index); + ?.[index] || + kArgTypeOperand; } } @@ -379,7 +384,7 @@ export class NNotepad { } throw new Error(`unexpected line type: ${line.type}`); } - function serializeExpr(expr, nonOperand = false) { + function serializeExpr(expr, argumentType = kArgTypeOperand) { if (expr.op) { if (expr.lhs) { return `_.${kBinaryOperators[expr.op]}(${serializeExpr(expr.lhs)}, ${ @@ -394,11 +399,21 @@ export class NNotepad { case 'boolean': return String(expr.value); case 'number': - return nonOperand ? Util.stringify(expr.value) : - serializeScalar(expr.value, expr.dataType); + switch (argumentType) { + case kArgTypeNonOperand: + return Util.stringify(expr.value); + default: + return serializeScalar(expr.value, expr.dataType); + } case 'array': - return nonOperand ? serializeArray(expr.value) : - serializeTensor(expr.value, expr.dataType); + switch (argumentType) { + case kArgTypeNonOperand: + return serializeArray(expr.value, kArgTypeNonOperand); + case kArgTypeOperandList: + return serializeArray(expr.value, kArgTypeOperand); + default: + return serializeTensor(expr.value, expr.dataType); + } case 'dict': return serializeDict(expr.dict); case 'identifier': @@ -414,7 +429,7 @@ export class NNotepad { .map((k) => { const v = dict[k]; k = Util.stringify(k); - return `${k}: ${serializeExpr(v, true)}`; + return `${k}: ${serializeExpr(v, kArgTypeNonOperand)}`; }) .join(', ') + '}'; @@ -465,8 +480,10 @@ export class NNotepad { elements.map((n) => Util.stringifyNumber(n, dataType)).join(',')}]))`; } - function serializeArray(array) { - return '[' + array.map((expr) => serializeExpr(expr)).join(', ') + ']'; + function serializeArray(array, argumentType) { + return '[' + + array.map((expr) => serializeExpr(expr, argumentType)).join(', ') + + ']'; } function serializeCall(name, args) { @@ -506,8 +523,8 @@ export class NNotepad { return `_.${name}(${ args.map( - (arg, index) => serializeExpr( - arg, WebNNUtil.isNonOperandArg(name, index))) + (arg, index) => + serializeExpr(arg, WebNNUtil.argumentType(name, index))) .join(', ')})`; } } diff --git a/nnotepad/js/tests.js b/nnotepad/js/tests.js index 5572b5c7..c789dd20 100644 --- a/nnotepad/js/tests.js +++ b/nnotepad/js/tests.js @@ -157,7 +157,7 @@ document.addEventListener('DOMContentLoaded', async (e) => { {dataType: 'float32', shape: [2], buffer: [3, 4]}, ]); - Harness.section('Multiple input tensors'); + Harness.section('Non-operand arguments: array of operands'); await test( `A = [1,2] B = [3,4] concat([A,B], 0)`, {dataType: 'float32', shape: [4], buffer: [1, 2, 3, 4]}); @@ -165,6 +165,19 @@ document.addEventListener('DOMContentLoaded', async (e) => { `concat([identity([1,2]),identity([3,4])], 0)`, {dataType: 'float32', shape: [4], buffer: [1, 2, 3, 4]}); + Harness.section('Non-operand arguments: array of numbers'); + await test( + `T = [[1,2,3],[4,5,6]] reshape(T, [1, 3, 2, 1])`, + {dataType: 'float32', shape: [1, 3, 2, 1], buffer: [1, 2, 3, 4, 5, 6]}); + await test( + `expand([1], [2, 2])`, + {dataType: 'float32', shape: [2, 2], buffer: [1, 1, 1, 1]}); + + Harness.section('Non-operand arguments: simple numbers'); + await test( + `softmax([1], 0)`, + {dataType: 'float32', shape: [1], buffer: [1]}); + Harness.section('Regression tests'); await test( `concat([[1,2],[3,4]], 0)`,