Skip to content

Commit

Permalink
[NNotepad] Fix serialization of non-tensor arrays (#264)
Browse files Browse the repository at this point in the history
Most MLGraphBuilder method arguments are operands, but sometimes
they're arrays or numbers. There was a mechanism in place to determine
(via a table) whether a given method's argument was an operator or
not, but that wasn't enough to distinguish an array of numbers (like
reshape() needs) from an array of operands (which concat() needs).

Rework the lookup from returning a boolean to returning an argument
type, and adjust serialization appropriately.

Fixes #263
  • Loading branch information
inexorabletash authored Jul 24, 2024
1 parent 707577f commit 458decc
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 24 deletions.
63 changes: 40 additions & 23 deletions nnotepad/js/nnotepad.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ export class ComputeError extends Error {
// General WebNN Utilities
// ============================================================

const kArgTypeOperandList = 1;
const kArgTypeNonOperand = 2;
const kArgTypeOperand = 3;

class WebNNUtil {
static bufferForOperand(operand) {
const size = [...operand.shape()].reduce((a, b) => a * b, 1);
Expand Down Expand Up @@ -60,21 +64,22 @@ class WebNNUtil {
throw new Error(`Unsupported dataType ${type}`);
}

static isNonOperandArg(name, index) {
static argumentType(name, index) {
return ({
concat: [0, 1],
expand: [1],
gru: [3, 4],
gruCell: [4],
lstm: [3, 4],
lstmCell: [5],
pad: [1, 2],
reshape: [1],
slice: [1, 2],
softmax: [1], // TODO: Distinguish overloads
split: [1],
concat: {0: kArgTypeOperandList, 1: kArgTypeNonOperand},
expand: {1: kArgTypeNonOperand},
gru: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand},
gruCell: {4: kArgTypeNonOperand},
lstm: {3: kArgTypeNonOperand, 4: kArgTypeNonOperand},
lstmCell: {5: kArgTypeNonOperand},
pad: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand},
reshape: {1: kArgTypeNonOperand},
slice: {1: kArgTypeNonOperand, 2: kArgTypeNonOperand},
softmax: {1: kArgTypeNonOperand},
split: {1: kArgTypeNonOperand},
})[name]
?.includes(index);
?.[index] ||
kArgTypeOperand;
}
}

Expand Down Expand Up @@ -379,7 +384,7 @@ export class NNotepad {
}
throw new Error(`unexpected line type: ${line.type}`);
}
function serializeExpr(expr, nonOperand = false) {
function serializeExpr(expr, argumentType = kArgTypeOperand) {
if (expr.op) {
if (expr.lhs) {
return `_.${kBinaryOperators[expr.op]}(${serializeExpr(expr.lhs)}, ${
Expand All @@ -394,11 +399,21 @@ export class NNotepad {
case 'boolean':
return String(expr.value);
case 'number':
return nonOperand ? Util.stringify(expr.value) :
serializeScalar(expr.value, expr.dataType);
switch (argumentType) {
case kArgTypeNonOperand:
return Util.stringify(expr.value);
default:
return serializeScalar(expr.value, expr.dataType);
}
case 'array':
return nonOperand ? serializeArray(expr.value) :
serializeTensor(expr.value, expr.dataType);
switch (argumentType) {
case kArgTypeNonOperand:
return serializeArray(expr.value, kArgTypeNonOperand);
case kArgTypeOperandList:
return serializeArray(expr.value, kArgTypeOperand);
default:
return serializeTensor(expr.value, expr.dataType);
}
case 'dict':
return serializeDict(expr.dict);
case 'identifier':
Expand All @@ -414,7 +429,7 @@ export class NNotepad {
.map((k) => {
const v = dict[k];
k = Util.stringify(k);
return `${k}: ${serializeExpr(v, true)}`;
return `${k}: ${serializeExpr(v, kArgTypeNonOperand)}`;
})
.join(', ') +
'}';
Expand Down Expand Up @@ -465,8 +480,10 @@ export class NNotepad {
elements.map((n) => Util.stringifyNumber(n, dataType)).join(',')}]))`;
}

function serializeArray(array) {
return '[' + array.map((expr) => serializeExpr(expr)).join(', ') + ']';
function serializeArray(array, argumentType) {
return '[' +
array.map((expr) => serializeExpr(expr, argumentType)).join(', ') +
']';
}

function serializeCall(name, args) {
Expand Down Expand Up @@ -506,8 +523,8 @@ export class NNotepad {

return `_.${name}(${
args.map(
(arg, index) => serializeExpr(
arg, WebNNUtil.isNonOperandArg(name, index)))
(arg, index) =>
serializeExpr(arg, WebNNUtil.argumentType(name, index)))
.join(', ')})`;
}
}
Expand Down
15 changes: 14 additions & 1 deletion nnotepad/js/tests.js
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,27 @@ document.addEventListener('DOMContentLoaded', async (e) => {
{dataType: 'float32', shape: [2], buffer: [3, 4]},
]);

Harness.section('Multiple input tensors');
Harness.section('Non-operand arguments: array of operands');
await test(
`A = [1,2] B = [3,4] concat([A,B], 0)`,
{dataType: 'float32', shape: [4], buffer: [1, 2, 3, 4]});
await test(
`concat([identity([1,2]),identity([3,4])], 0)`,
{dataType: 'float32', shape: [4], buffer: [1, 2, 3, 4]});

Harness.section('Non-operand arguments: array of numbers');
await test(
`T = [[1,2,3],[4,5,6]] reshape(T, [1, 3, 2, 1])`,
{dataType: 'float32', shape: [1, 3, 2, 1], buffer: [1, 2, 3, 4, 5, 6]});
await test(
`expand([1], [2, 2])`,
{dataType: 'float32', shape: [2, 2], buffer: [1, 1, 1, 1]});

Harness.section('Non-operand arguments: simple numbers');
await test(
`softmax([1], 0)`,
{dataType: 'float32', shape: [1], buffer: [1]});

Harness.section('Regression tests');
await test(
`concat([[1,2],[3,4]], 0)`,
Expand Down

0 comments on commit 458decc

Please sign in to comment.