This repository has been archived by the owner on Nov 16, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
wasm: Binary Operators, MatMul operator, shape_utils, and broadcast_utils (#47) * Binary ops for Wasm backend * MatMul op for Wasm backend
- Loading branch information
1 parent
67e0204
commit 0be3102
Showing
27 changed files
with
831 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT license. | ||
|
||
import {BinaryOp} from '../../../ops/binary-op'; | ||
import {Tensor} from '../../../tensor'; | ||
import {BroadcastUtil} from '../../../util'; | ||
import {WasmBinding} from '../../../wasm-binding'; | ||
import {WasmInferenceHandler} from '../inference-handler'; | ||
|
||
export class WasmBinaryOp extends BinaryOp { | ||
constructor(typeConstraint: ReadonlyArray<Tensor.DataType>, opType: string, resultType?: Tensor.DataType) { | ||
super(typeConstraint, opType, resultType); | ||
} | ||
|
||
run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] { | ||
const outputShape = BroadcastUtil.calcShape(inputs[0].dims, inputs[1].dims, false); | ||
if (!outputShape) { | ||
throw new Error('not broadcastable'); | ||
} | ||
let fun = ''; | ||
// TODO: Explore better ways to deal with types than current `binaryOpType` approach | ||
let binaryOpType = ''; | ||
switch (this.opType) { | ||
case 'Add': | ||
if (inputs[0].type === 'float32') { | ||
fun = '_add_f32'; | ||
binaryOpType = 'float32InFloat32Out'; | ||
} | ||
break; | ||
case 'Sub': | ||
if (inputs[0].type === 'float32') { | ||
fun = '_sub_f32'; | ||
binaryOpType = 'float32InFloat32Out'; | ||
} | ||
break; | ||
case 'Mul': | ||
if (inputs[0].type === 'float32') { | ||
fun = '_mul_f32'; | ||
binaryOpType = 'float32InFloat32Out'; | ||
} | ||
break; | ||
case 'Div': | ||
if (inputs[0].type === 'float32') { | ||
fun = '_div_f32'; | ||
binaryOpType = 'float32InFloat32Out'; | ||
} | ||
break; | ||
case 'PRelu': | ||
if (inputs[0].type === 'float32') { | ||
fun = '_prelu_f32'; | ||
binaryOpType = 'float32InFloat32Out'; | ||
} | ||
break; | ||
case 'Xor': | ||
fun = '_xor_u8'; | ||
binaryOpType = 'boolInBoolOut'; | ||
break; | ||
case 'Or': | ||
fun = '_or_u8'; | ||
binaryOpType = 'boolInBoolOut'; | ||
break; | ||
case 'And': | ||
fun = '_and_u8'; | ||
binaryOpType = 'boolInBoolOut'; | ||
break; | ||
default: | ||
throw Error(`unsupported binary op by the Wasm backend`); | ||
} | ||
let result: Tensor; | ||
if (binaryOpType === 'float32InFloat32Out') { | ||
result = new Tensor(outputShape, 'float32'); | ||
WasmBinding.getInstance().ccall( | ||
fun, [inputs[0].floatData, 'float32ptr'], [inputs[0].dims.length, 'int32'], [inputs[0].dims, 'int32ptr'], | ||
[inputs[1].floatData, 'float32ptr'], [inputs[1].dims.length, 'int32'], [inputs[1].dims, 'int32ptr'], | ||
[result.floatData, 'float32ptr', 'out'], [result.floatData.length, 'int32'], [outputShape.length, 'int32'], | ||
[outputShape, 'int32ptr']); | ||
} else if (binaryOpType === 'boolInBoolOut') { | ||
result = new Tensor(outputShape, 'bool'); | ||
WasmBinding.getInstance().ccall( | ||
fun, [inputs[0].integerData as Uint8Array, 'boolptr'], [inputs[0].dims.length, 'int32'], | ||
[inputs[0].dims, 'int32ptr'], [inputs[1].integerData as Uint8Array, 'boolptr'], | ||
[inputs[1].dims.length, 'int32'], [inputs[1].dims, 'int32ptr'], | ||
[result.integerData as Uint8Array, 'boolptr', 'out'], [result.integerData.length, 'int32'], | ||
[outputShape.length, 'int32'], [outputShape, 'int32ptr']); | ||
} else { | ||
throw new Error(`Unsupported binary op format. Probably unsupported data types.`); | ||
} | ||
return [result]; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT license. | ||
|
||
import {MatMul} from '../../../ops/matmul'; | ||
import {Tensor} from '../../../tensor'; | ||
import {BroadcastUtil, MatMulUtil, ShapeUtil} from '../../../util'; | ||
import {WasmBinding} from '../../../wasm-binding'; | ||
import {WasmInferenceHandler} from '../inference-handler'; | ||
|
||
export class WasmMatMul extends MatMul { | ||
run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] { | ||
let dimsA: number[]; | ||
let dimsB: number[]; | ||
[dimsA, dimsB] = MatMulUtil.preprocessInputShapes(inputs[0].dims.slice(), inputs[1].dims.slice()); | ||
const outputShape = BroadcastUtil.calcShape(dimsA, dimsB, true); | ||
if (!outputShape) { | ||
// the inputs cannot broadcast or cannot multiply | ||
throw new Error(`input dimensions do not match the requirement`); | ||
} | ||
|
||
const outputSize = ShapeUtil.size(outputShape); | ||
const resultData = new Float32Array(outputSize); | ||
WasmBinding.getInstance().ccall( | ||
'_matmul_f32', [inputs[0].floatData, 'float32ptr'], [inputs[0].dims, 'int32ptr'], | ||
[inputs[0].dims.length, 'int32'], [inputs[1].floatData, 'float32ptr'], [inputs[1].dims, 'int32ptr'], | ||
[inputs[1].dims.length, 'int32'], [resultData, 'float32ptr', 'out'], [resultData.length, 'int32'], | ||
[outputShape, 'int32ptr'], [outputShape.length, 'int32']); | ||
MatMulUtil.postprocessOutputShape(outputShape as number[], inputs[0].dims.length, inputs[1].dims.length); | ||
const result = new Tensor(outputShape, inputs[0].type); | ||
result.floatData.set(resultData); | ||
return [result]; | ||
} | ||
|
||
// overriding the checkInputTypes() in the base class because Wasm backend has special type limitations | ||
checkInputTypes(inputs: Tensor[]): boolean { | ||
// currently Wasm backend only supports 'float32' input type | ||
if (inputs[0].type !== 'float32' || inputs[1].type !== 'float32') { | ||
return false; | ||
} | ||
|
||
if (inputs[0].type !== inputs[1].type) { | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.