
Commit

wasm: Binary Operators, MatMul operator, shape_utils, and broadcast_utils (#47)

* Binary ops for Wasm backend

* MatMul op for Wasm backend
hariharans29 authored Jan 25, 2019
1 parent 67e0204 commit 0be3102
Showing 27 changed files with 831 additions and 72 deletions.
16 changes: 8 additions & 8 deletions docs/operators.md
@@ -6,8 +6,8 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
|:------------------------------------------------------------------------------------------------------:|:-----------:|:------------:|:-------------:|
| [Abs](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Abs) | :heavy_check_mark: | | :heavy_check_mark: |
| [Acos](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Acos) | :heavy_check_mark: | | :heavy_check_mark: |
| [Add](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Add) | :heavy_check_mark: | | :heavy_check_mark: |
| [And](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#And) | :heavy_check_mark: | | :heavy_check_mark: |
| [Add](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Add) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [And](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#And) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [ArgMax](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#ArgMax) | :heavy_check_mark: | | |
| [Asin](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Asin) | :heavy_check_mark: | | :heavy_check_mark: |
| [Atan](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Atan) | :heavy_check_mark: | | :heavy_check_mark: |
Expand All @@ -18,7 +18,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [Constant](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Constant) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [Conv](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Conv) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [Cos](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Cos) | :heavy_check_mark: | | :heavy_check_mark: |
| [Div](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Div) | :heavy_check_mark: | | :heavy_check_mark: |
| [Div](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Div) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [Dropout](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Dropout) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [Equal](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Equal) | | | :heavy_check_mark: |
| [Exp](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Exp) | :heavy_check_mark: | | :heavy_check_mark: |
Expand All @@ -34,12 +34,12 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [Less](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Less) | | | :heavy_check_mark: |
| [Log](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Log) | :heavy_check_mark: | | :heavy_check_mark: |
| [LRN](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#LRN) | :heavy_check_mark: | | |
| [MatMul](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#MatMul) | :heavy_check_mark: | | :heavy_check_mark: |
| [MatMul](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#MatMul) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [MaxPool](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#MaxPool) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [Mul](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Mul) | :heavy_check_mark: | | :heavy_check_mark: |
| [Mul](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Mul) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [Neg](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Neg) | :heavy_check_mark: | | :heavy_check_mark: |
| [Not](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Not) | | | :heavy_check_mark: |
| [Or](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Or) | :heavy_check_mark: | | :heavy_check_mark: |
| [Or](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Or) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [Pad](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Pad) | | | :heavy_check_mark: |
| [Pow](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Pow) | | | :heavy_check_mark: |
| [PRelu](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#PRelu) | :heavy_check_mark: | | :heavy_check_mark: |
Expand All @@ -58,11 +58,11 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [Softmax](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Softmax) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [Split](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Split) | | | :heavy_check_mark: |
| [Sqrt](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Sqrt) | :heavy_check_mark: | | :heavy_check_mark: |
| [Sub](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Sub) | :heavy_check_mark: | | :heavy_check_mark: |
| [Sub](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Sub) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [Sum](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Sum) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| [Tan](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Tan) | :heavy_check_mark: | | :heavy_check_mark: |
| [Tanh](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Tanh) | :heavy_check_mark: | | :heavy_check_mark: |
| [Tile](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Tile) | :heavy_check_mark: | | :heavy_check_mark: |
| [Transpose](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Transpose) | :heavy_check_mark: | | :heavy_check_mark: |
| [Unsqueeze](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Unsqueeze) | :heavy_check_mark: | | |
| [Xor](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Xor) | :heavy_check_mark: | | :heavy_check_mark: |
| [Xor](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Xor) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
1 change: 1 addition & 0 deletions lib/backends/cpu/ops/binary-op.ts
@@ -16,6 +16,7 @@ export class CpuBinaryOp extends BinaryOp {
super(typeConstraint, opType, resultType);
}

// overriding the initialize() in the base class
initialize(attributes: Attribute): void {
if (!this.opType && !this.opLambda) {
throw new Error(`Both opType and opLambda cannot be missing for a binary op`);
34 changes: 7 additions & 27 deletions lib/backends/cpu/ops/matmul.ts
@@ -5,7 +5,7 @@ import ndarray from 'ndarray';

import {MatMul} from '../../../ops/matmul';
import {Tensor} from '../../../tensor';
import {BroadcastUtil, ShapeUtil} from '../../../util';
import {BroadcastUtil, MatMulUtil, ShapeUtil} from '../../../util';
import {CpuInferenceHandler} from '../inference-handler';

export class CpuMatMul extends MatMul {
Expand All @@ -16,28 +16,17 @@ export class CpuMatMul extends MatMul {
}

export function matMul(a: Tensor, b: Tensor) {
let dimsA = a.dims.slice(0);
let dimsB = b.dims.slice(0);
// If the first argument is 1-D, it is promoted to a matrix by prepending
// a 1 to its dimensions. After matrix multiplication the prepended 1 is
// removed.
if (a.dims.length === 1) {
dimsA = [1, dimsA[0]];
}
// If the second argument is 1-D, it is promoted to a matrix by appending
// a 1 to its dimensions. After matrix multiplication the appended 1 is
// removed.
if (b.dims.length === 1) {
dimsB = [dimsB[0], 1];
}

let dimsA: number[];
let dimsB: number[];
[dimsA, dimsB] = MatMulUtil.preprocessInputShapes(a.dims.slice(), b.dims.slice());
const mat2dShape = [dimsA[dimsA.length - 2], dimsB[dimsB.length - 1]];
let shape = BroadcastUtil.calcShape(dimsA, dimsB, true);
if (!shape) {
// the inputs cannot broadcast or cannot multiply
throw new Error(`input dimensions do not match the requirement`);
}

// make a copy and re-assign because this can be modified later
shape = shape.slice(0);
const size = ShapeUtil.size(shape);
const num2dMatrices = size / (mat2dShape[0] * mat2dShape[1]);

@@ -75,18 +64,9 @@ export function matMul(a: Tensor, b: Tensor) {
MatMul2d(subarrayA, subarrayB, ndY, curPos);
curPos += mat2dShape[0] * mat2dShape[1];
}

// Remove prepended dimension if first input is 1d
if (a.dims.length === 1) {
shape = shape.slice(0, shape.length - 2).concat(shape.slice(shape.length - 1));
}
// Remove appended dimension if second input is 1d
if (b.dims.length === 1) {
shape = shape.slice(0, shape.length - 1);
}
MatMulUtil.postprocessOutputShape(shape as number[], a.dims.length, b.dims.length);
const tensorY = new Tensor(shape, isFloat64 ? 'float64' : 'float32');
tensorY.floatData.set(ndY.data);

return tensorY;
}

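The helpers referenced above are new in this commit. Below is a minimal sketch of what MatMulUtil.preprocessInputShapes and MatMulUtil.postprocessOutputShape presumably do, reconstructed from the 1-D promotion logic removed in this diff; the actual implementations live in lib/util.ts and may differ.

// Sketch only: ONNX MatMul follows NumPy semantics. A 1-D first input is
// promoted to a matrix by prepending a 1 to its dims; a 1-D second input is
// promoted by appending a 1. The promoted dims are used for the multiply.
function preprocessInputShapes(dimsA: number[], dimsB: number[]): [number[], number[]] {
  const a = dimsA.length === 1 ? [1, dimsA[0]] : dimsA;
  const b = dimsB.length === 1 ? [dimsB[0], 1] : dimsB;
  return [a, b];
}

// Sketch only: after the multiply, the dimensions added above are removed
// again from the output shape (in place, mirroring how the diff calls it).
function postprocessOutputShape(shape: number[], aRank: number, bRank: number): void {
  if (aRank === 1) {
    shape.splice(shape.length - 2, 1);  // drop the prepended 1
  }
  if (bRank === 1) {
    shape.splice(shape.length - 1, 1);  // drop the appended 1
  }
}

// Example: [2, 3] x [3] is computed as [2, 3] x [3, 1] = [2, 1], and the
// appended 1 is then dropped, giving an output of shape [2].
const [pa, pb] = preprocessInputShapes([2, 3], [3]);  // [[2, 3], [3, 1]]
const outShape = [pa[0], pb[1]];                      // [2, 1]
postprocessOutputShape(outShape, 2, 1);               // outShape is now [2]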
6 changes: 0 additions & 6 deletions lib/backends/cpu/ops/unary-op.ts
@@ -14,19 +14,13 @@ export class CpuUnaryOp extends UnaryOp {
super(typeConstraint, resultType);
}

initialize(attributes: Attribute): void {
this.attributes = attributes;
}

run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
// TODO: use webpack + ts-loader + CustomTransformer
// tslint:disable-next-line:max-line-length
// https://github.com/TypeStrong/ts-loader#getcustomtransformers-----before-transformerfactory-after-transformerfactory--
const output = unaryOp(inputs[0], this.func, this.attributes, this.resultType);
return [output];
}

private attributes: Attribute;
}

export function unaryOp(
90 changes: 90 additions & 0 deletions lib/backends/wasm/ops/binary-op.ts
@@ -0,0 +1,90 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {BinaryOp} from '../../../ops/binary-op';
import {Tensor} from '../../../tensor';
import {BroadcastUtil} from '../../../util';
import {WasmBinding} from '../../../wasm-binding';
import {WasmInferenceHandler} from '../inference-handler';

export class WasmBinaryOp extends BinaryOp {
constructor(typeConstraint: ReadonlyArray<Tensor.DataType>, opType: string, resultType?: Tensor.DataType) {
super(typeConstraint, opType, resultType);
}

run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
const outputShape = BroadcastUtil.calcShape(inputs[0].dims, inputs[1].dims, false);
if (!outputShape) {
throw new Error('not broadcastable');
}
let fun = '';
// TODO: Explore better ways to deal with types than current `binaryOpType` approach
let binaryOpType = '';
switch (this.opType) {
case 'Add':
if (inputs[0].type === 'float32') {
fun = '_add_f32';
binaryOpType = 'float32InFloat32Out';
}
break;
case 'Sub':
if (inputs[0].type === 'float32') {
fun = '_sub_f32';
binaryOpType = 'float32InFloat32Out';
}
break;
case 'Mul':
if (inputs[0].type === 'float32') {
fun = '_mul_f32';
binaryOpType = 'float32InFloat32Out';
}
break;
case 'Div':
if (inputs[0].type === 'float32') {
fun = '_div_f32';
binaryOpType = 'float32InFloat32Out';
}
break;
case 'PRelu':
if (inputs[0].type === 'float32') {
fun = '_prelu_f32';
binaryOpType = 'float32InFloat32Out';
}
break;
case 'Xor':
fun = '_xor_u8';
binaryOpType = 'boolInBoolOut';
break;
case 'Or':
fun = '_or_u8';
binaryOpType = 'boolInBoolOut';
break;
case 'And':
fun = '_and_u8';
binaryOpType = 'boolInBoolOut';
break;
default:
throw Error(`unsupported binary op by the Wasm backend`);
}
let result: Tensor;
if (binaryOpType === 'float32InFloat32Out') {
result = new Tensor(outputShape, 'float32');
WasmBinding.getInstance().ccall(
fun, [inputs[0].floatData, 'float32ptr'], [inputs[0].dims.length, 'int32'], [inputs[0].dims, 'int32ptr'],
[inputs[1].floatData, 'float32ptr'], [inputs[1].dims.length, 'int32'], [inputs[1].dims, 'int32ptr'],
[result.floatData, 'float32ptr', 'out'], [result.floatData.length, 'int32'], [outputShape.length, 'int32'],
[outputShape, 'int32ptr']);
} else if (binaryOpType === 'boolInBoolOut') {
result = new Tensor(outputShape, 'bool');
WasmBinding.getInstance().ccall(
fun, [inputs[0].integerData as Uint8Array, 'boolptr'], [inputs[0].dims.length, 'int32'],
[inputs[0].dims, 'int32ptr'], [inputs[1].integerData as Uint8Array, 'boolptr'],
[inputs[1].dims.length, 'int32'], [inputs[1].dims, 'int32ptr'],
[result.integerData as Uint8Array, 'boolptr', 'out'], [result.integerData.length, 'int32'],
[outputShape.length, 'int32'], [outputShape, 'int32ptr']);
} else {
throw new Error(`Unsupported binary op format. Probably unsupported data types.`);
}
return [result];
}
}
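The output shape for these element-wise ops comes from BroadcastUtil.calcShape using NumPy-style multidirectional broadcasting. The following self-contained sketch illustrates that rule; the real helper in lib/util.ts also covers the MatMul case via its third argument and may differ in detail.

// Sketch only: align the shapes from the right; each pair of dimensions must
// be equal, or one of them must be 1 (missing leading dims act like 1).
// Returns undefined when the shapes cannot be broadcast together.
function calcBroadcastShape(a: readonly number[], b: readonly number[]): number[]|undefined {
  const rank = Math.max(a.length, b.length);
  const out = new Array<number>(rank);
  for (let i = 1; i <= rank; i++) {
    const da = i <= a.length ? a[a.length - i] : 1;
    const db = i <= b.length ? b[b.length - i] : 1;
    if (da !== db && da !== 1 && db !== 1) {
      return undefined;  // e.g. [2, 3] and [3, 2] cannot broadcast
    }
    out[rank - i] = Math.max(da, db);
  }
  return out;
}

// Examples: [2, 3, 4] with [4] -> [2, 3, 4]; [5, 1, 3] with [4, 3] -> [5, 4, 3]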
47 changes: 47 additions & 0 deletions lib/backends/wasm/ops/matmul.ts
@@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

import {MatMul} from '../../../ops/matmul';
import {Tensor} from '../../../tensor';
import {BroadcastUtil, MatMulUtil, ShapeUtil} from '../../../util';
import {WasmBinding} from '../../../wasm-binding';
import {WasmInferenceHandler} from '../inference-handler';

export class WasmMatMul extends MatMul {
run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
let dimsA: number[];
let dimsB: number[];
[dimsA, dimsB] = MatMulUtil.preprocessInputShapes(inputs[0].dims.slice(), inputs[1].dims.slice());
const outputShape = BroadcastUtil.calcShape(dimsA, dimsB, true);
if (!outputShape) {
// the inputs cannot broadcast or cannot multiply
throw new Error(`input dimensions do not match the requirement`);
}

const outputSize = ShapeUtil.size(outputShape);
const resultData = new Float32Array(outputSize);
WasmBinding.getInstance().ccall(
'_matmul_f32', [inputs[0].floatData, 'float32ptr'], [inputs[0].dims, 'int32ptr'],
[inputs[0].dims.length, 'int32'], [inputs[1].floatData, 'float32ptr'], [inputs[1].dims, 'int32ptr'],
[inputs[1].dims.length, 'int32'], [resultData, 'float32ptr', 'out'], [resultData.length, 'int32'],
[outputShape, 'int32ptr'], [outputShape.length, 'int32']);
MatMulUtil.postprocessOutputShape(outputShape as number[], inputs[0].dims.length, inputs[1].dims.length);
const result = new Tensor(outputShape, inputs[0].type);
result.floatData.set(resultData);
return [result];
}

// overriding the checkInputTypes() in the base class because Wasm backend has special type limitations
checkInputTypes(inputs: Tensor[]): boolean {
// currently Wasm backend only supports 'float32' input type
if (inputs[0].type !== 'float32' || inputs[1].type !== 'float32') {
return false;
}

if (inputs[0].type !== inputs[1].type) {
return false;
}

return true;
}
}
24 changes: 24 additions & 0 deletions lib/backends/wasm/session-handler.ts
@@ -8,8 +8,10 @@ import {Session} from '../../session';
import {resolve} from '../cpu/ops-resolve';
import {WasmInferenceHandler} from './inference-handler';
import {WasmBatchNormalization} from './ops/batch-normalization';
import {WasmBinaryOp} from './ops/binary-op';
import {WasmConv} from './ops/conv';
import {WasmGemm} from './ops/gemm';
import {WasmMatMul} from './ops/matmul';
import {WasmAveragePool, WasmGlobalAveragePool, WasmGlobalMaxPool, WasmMaxPool} from './ops/pool';
import {WasmSoftmax} from './ops/softmax';
import {WasmSum} from './ops/sum';
Expand All @@ -32,12 +34,32 @@ export class WasmSessionHandler implements SessionHandler {
private createOperator(node: Graph.Node, domain: string, version: number): Operator {
// assume domain=ai.onnx, version=v7
switch (node.opType) {
// Binary arithmetic ops
case 'Add':
return new WasmBinaryOp(['float32'], 'Add');
case 'Sub':
return new WasmBinaryOp(['float32'], 'Sub');
case 'Mul':
return new WasmBinaryOp(['float32'], 'Mul');
case 'Div':
// TODO: Handle division by zero
return new WasmBinaryOp(['float32'], 'Div');
// Binary logical ops
case 'Xor':
return new WasmBinaryOp(['bool'], 'Xor');
case 'Or':
return new WasmBinaryOp(['bool'], 'Or');
case 'And':
return new WasmBinaryOp(['bool'], 'And');
// Misc ops
case 'Conv':
return new WasmConv();
case 'BatchNormalization':
return new WasmBatchNormalization();
case 'Gemm':
return new WasmGemm();
case 'MatMul':
return new WasmMatMul();
case 'Softmax':
return new WasmSoftmax();
case 'Sum':
Expand All @@ -50,6 +72,8 @@ export class WasmSessionHandler implements SessionHandler {
return new WasmGlobalMaxPool();
case 'GlobalAveragePool':
return new WasmGlobalAveragePool();
case 'PRelu':
return new WasmBinaryOp(['float32'], 'PRelu');
default:
if (this.fallbackToCpuOps) {
return resolve(node, domain, version);
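For context, here is a sketch of the dispatch pattern used above: the handler prefers a Wasm kernel when one is registered for the node's opType and otherwise falls back to the generic CPU resolver when fallbackToCpuOps is enabled. This is illustrative only; Kernel, wasmKernels, createOperatorSketch and resolveCpu are hypothetical names introduced for the example, not the library's actual API.

// Sketch only: mirrors the switch/default structure of createOperator above.
interface Kernel {
  opType: string;
}

function createOperatorSketch(
    opType: string, wasmKernels: Map<string, () => Kernel>, resolveCpu: (opType: string) => Kernel,
    fallbackToCpuOps: boolean): Kernel {
  const makeWasmKernel = wasmKernels.get(opType);
  if (makeWasmKernel) {
    return makeWasmKernel();  // e.g. 'Add', 'MatMul', 'Conv' after this commit
  }
  if (fallbackToCpuOps) {
    return resolveCpu(opType);  // e.g. 'Abs', which has no Wasm kernel in the table above
  }
  throw new Error(`unrecognized operator '${opType}'`);
}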