diff --git a/docs/operators.md b/docs/operators.md
index d493b32d..81077e7e 100644
--- a/docs/operators.md
+++ b/docs/operators.md
@@ -6,8 +6,8 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
 |:------------------------------------------------------------------------------------------------------:|:-----------:|:------------:|:-------------:|
 | [Abs](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Abs) | :heavy_check_mark: | | :heavy_check_mark: |
 | [Acos](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Acos) | :heavy_check_mark: | | :heavy_check_mark: |
-| [Add](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Add) | :heavy_check_mark: | | :heavy_check_mark: |
-| [And](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#And) | :heavy_check_mark: | | :heavy_check_mark: |
+| [Add](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Add) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
+| [And](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#And) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | [ArgMax](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#ArgMax) | :heavy_check_mark: | | |
 | [Asin](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Asin) | :heavy_check_mark: | | :heavy_check_mark: |
 | [Atan](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Atan) | :heavy_check_mark: | | :heavy_check_mark: |
@@ -18,7 +18,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
 | [Constant](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Constant) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | [Conv](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Conv) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | [Cos](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Cos) | :heavy_check_mark: | | :heavy_check_mark: |
-| [Div](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Div) | :heavy_check_mark: | | :heavy_check_mark: |
+| [Div](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Div) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | [Dropout](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Dropout) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | [Equal](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Equal) | | | :heavy_check_mark: |
 | [Exp](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Exp) | :heavy_check_mark: | | :heavy_check_mark: |
@@ -34,12 +34,12 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
 | [Less](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Less) | | | :heavy_check_mark: |
 | [Log](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Log) | :heavy_check_mark: | | :heavy_check_mark: |
 | [LRN](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#LRN) | :heavy_check_mark: | | |
-| [MatMul](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#MatMul) | :heavy_check_mark: | | :heavy_check_mark: |
+| [MatMul](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#MatMul) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | [MaxPool](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#MaxPool) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
-| [Mul](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Mul) | :heavy_check_mark: | | :heavy_check_mark: |
+| [Mul](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Mul) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | [Neg](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Neg) | :heavy_check_mark: | | :heavy_check_mark: |
 | [Not](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Not) | | | :heavy_check_mark: |
-| [Or](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Or) | :heavy_check_mark: | | :heavy_check_mark: |
+| [Or](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Or) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | [Pad](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Pad) | | | :heavy_check_mark: |
 | [Pow](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Pow) | | | :heavy_check_mark: |
 | [PRelu](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#PRelu) | :heavy_check_mark: | | :heavy_check_mark: |
@@ -58,11 +58,11 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
 | [Softmax](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Softmax) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | [Split](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Split) | | | :heavy_check_mark: |
 | [Sqrt](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Sqrt) | :heavy_check_mark: | | :heavy_check_mark: |
-| [Sub](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Sub) | :heavy_check_mark: | | :heavy_check_mark: |
+| [Sub](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Sub) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | [Sum](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Sum) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
 | [Tan](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Tan) | :heavy_check_mark: | | :heavy_check_mark: |
 | [Tanh](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Tanh) | :heavy_check_mark: | | :heavy_check_mark: |
 | [Tile](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Tile) | :heavy_check_mark: | | :heavy_check_mark: |
 | [Transpose](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Transpose) | :heavy_check_mark: | | :heavy_check_mark: |
 | [Unsqueeze](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Unsqueeze) | :heavy_check_mark: | | |
-| [Xor](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Xor) | :heavy_check_mark: | | :heavy_check_mark: |
+| [Xor](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md#Xor) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
diff --git a/lib/backends/cpu/ops/binary-op.ts b/lib/backends/cpu/ops/binary-op.ts
index 17efc79b..114a5a16 100644
--- a/lib/backends/cpu/ops/binary-op.ts
+++ b/lib/backends/cpu/ops/binary-op.ts
@@ -16,6 +16,7 @@ export class CpuBinaryOp extends BinaryOp {
     super(typeConstraint, opType, resultType);
   }
 
+  // overriding the initialize() in the base class
   initialize(attributes: Attribute): void {
     if (!this.opType && !this.opLambda) {
       throw new Error(`Both opType and opLambda cannot be missing for a binary op`);
diff --git a/lib/backends/cpu/ops/matmul.ts b/lib/backends/cpu/ops/matmul.ts
index d59921da..501795de 100644
--- a/lib/backends/cpu/ops/matmul.ts
+++ b/lib/backends/cpu/ops/matmul.ts
@@ -5,7 +5,7 @@
 import ndarray from 'ndarray';
 import {MatMul} from '../../../ops/matmul';
 import {Tensor} from '../../../tensor';
-import {BroadcastUtil, ShapeUtil} from '../../../util';
+import {BroadcastUtil, MatMulUtil, ShapeUtil} from '../../../util';
 import {CpuInferenceHandler} from '../inference-handler';
 
 export class CpuMatMul extends MatMul {
@@ -16,28 +16,17 @@ export class CpuMatMul extends MatMul {
 }
 
 export function matMul(a: Tensor, b: Tensor) {
-  let dimsA = a.dims.slice(0);
-  let dimsB = b.dims.slice(0);
-  // If the first argument is 1-D, it is promoted to a matrix by prepending
-  // a 1 to its dimensions. After matrix multiplication the prepended 1 is
-  // removed.
-  if (a.dims.length === 1) {
-    dimsA = [1, dimsA[0]];
-  }
-  // If the second argument is 1-D, it is promoted to a matrix by appending
-  // a 1 to its dimensions. After matrix multiplication the appended 1 is
-  // removed.
-  if (b.dims.length === 1) {
-    dimsB = [dimsB[0], 1];
-  }
-
+  let dimsA: number[];
+  let dimsB: number[];
+  [dimsA, dimsB] = MatMulUtil.preprocessInputShapes(a.dims.slice(), b.dims.slice());
   const mat2dShape = [dimsA[dimsA.length - 2], dimsB[dimsB.length - 1]];
   let shape = BroadcastUtil.calcShape(dimsA, dimsB, true);
   if (!shape) {
     // the inputs cannot broadcast or cannot multiply
     throw new Error(`input dimensions do not match the requirement`);
   }
-
+  // make a copy and re-assign because this can be modified later
+  shape = shape.slice(0);
   const size = ShapeUtil.size(shape);
   const num2dMatrices = size / (mat2dShape[0] * mat2dShape[1]);
@@ -75,18 +64,9 @@ export function matMul(a: Tensor, b: Tensor) {
     MatMul2d(subarrayA, subarrayB, ndY, curPos);
     curPos += mat2dShape[0] * mat2dShape[1];
   }
-
-  // Remove prepended dimension if first input is 1d
-  if (a.dims.length === 1) {
-    shape = shape.slice(0, shape.length - 2).concat(shape.slice(shape.length - 1));
-  }
-  // Remove appended dimension if second input is 1d
-  if (b.dims.length === 1) {
-    shape = shape.slice(0, shape.length - 1);
-  }
+  MatMulUtil.postprocessOutputShape(shape as number[], a.dims.length, b.dims.length);
   const tensorY = new Tensor(shape, isFloat64 ? 'float64' : 'float32');
   tensorY.floatData.set(ndY.data);
-
   return tensorY;
 }
diff --git a/lib/backends/cpu/ops/unary-op.ts b/lib/backends/cpu/ops/unary-op.ts
index 90632c6b..a2d7707d 100644
--- a/lib/backends/cpu/ops/unary-op.ts
+++ b/lib/backends/cpu/ops/unary-op.ts
@@ -14,10 +14,6 @@ export class CpuUnaryOp extends UnaryOp {
     super(typeConstraint, resultType);
   }
 
-  initialize(attributes: Attribute): void {
-    this.attributes = attributes;
-  }
-
   run(inferenceHandler: CpuInferenceHandler, inputs: Tensor[]): Tensor[] {
     // TODO: use webpack + ts-loader + CustomTransformer
     // tslint:disable-next-line:max-line-length
@@ -25,8 +21,6 @@ export class CpuUnaryOp extends UnaryOp {
     const output = unaryOp(inputs[0], this.func, this.attributes, this.resultType);
     return [output];
   }
-
-  private attributes: Attribute;
 }
 
 export function unaryOp(
diff --git a/lib/backends/wasm/ops/binary-op.ts b/lib/backends/wasm/ops/binary-op.ts
new file mode 100644
index 00000000..948a0419
--- /dev/null
+++ b/lib/backends/wasm/ops/binary-op.ts
@@ -0,0 +1,90 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+import {BinaryOp} from '../../../ops/binary-op';
+import {Tensor} from '../../../tensor';
+import {BroadcastUtil} from '../../../util';
+import {WasmBinding} from '../../../wasm-binding';
+import {WasmInferenceHandler} from '../inference-handler';
+
+export class WasmBinaryOp extends BinaryOp {
+  constructor(typeConstraint: ReadonlyArray<Tensor.DataType>, opType: string, resultType?: Tensor.DataType) {
+    super(typeConstraint, opType, resultType);
+  }
+
+  run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
+    const outputShape = BroadcastUtil.calcShape(inputs[0].dims, inputs[1].dims, false);
+    if (!outputShape) {
+      throw new Error('not broadcastable');
+    }
+    let fun = '';
+    // TODO: Explore better ways to deal with types than current `binaryOpType` approach
+    let binaryOpType = '';
+    switch (this.opType) {
+      case 'Add':
+        if (inputs[0].type === 'float32') {
+          fun = '_add_f32';
+          binaryOpType = 'float32InFloat32Out';
+        }
+        break;
+      case 'Sub':
+        if (inputs[0].type === 'float32') {
+          fun = '_sub_f32';
+          binaryOpType = 'float32InFloat32Out';
+        }
+        break;
+      case 'Mul':
+        if (inputs[0].type === 'float32') {
+          fun = '_mul_f32';
+          binaryOpType = 'float32InFloat32Out';
+        }
+        break;
+      case 'Div':
+        if (inputs[0].type === 'float32') {
+          fun = '_div_f32';
+          binaryOpType = 'float32InFloat32Out';
+        }
+        break;
+      case 'PRelu':
+        if (inputs[0].type === 'float32') {
+          fun = '_prelu_f32';
+          binaryOpType = 'float32InFloat32Out';
+        }
+        break;
+      case 'Xor':
+        fun = '_xor_u8';
+        binaryOpType = 'boolInBoolOut';
+        break;
+      case 'Or':
+        fun = '_or_u8';
+        binaryOpType = 'boolInBoolOut';
+        break;
+      case 'And':
+        fun = '_and_u8';
+        binaryOpType = 'boolInBoolOut';
+        break;
+      default:
+        throw Error(`unsupported binary op by the Wasm backend`);
+    }
+    let result: Tensor;
+    if (binaryOpType === 'float32InFloat32Out') {
+      result = new Tensor(outputShape, 'float32');
+      WasmBinding.getInstance().ccall(
+          fun, [inputs[0].floatData, 'float32ptr'], [inputs[0].dims.length, 'int32'], [inputs[0].dims, 'int32ptr'],
+          [inputs[1].floatData, 'float32ptr'], [inputs[1].dims.length, 'int32'], [inputs[1].dims, 'int32ptr'],
+          [result.floatData, 'float32ptr', 'out'], [result.floatData.length, 'int32'], [outputShape.length, 'int32'],
+          [outputShape, 'int32ptr']);
+    } else if (binaryOpType === 'boolInBoolOut') {
+      result = new Tensor(outputShape, 'bool');
+      WasmBinding.getInstance().ccall(
+          fun, [inputs[0].integerData as Uint8Array, 'boolptr'], [inputs[0].dims.length, 'int32'],
+          [inputs[0].dims, 'int32ptr'], [inputs[1].integerData as Uint8Array, 'boolptr'],
+          [inputs[1].dims.length, 'int32'], [inputs[1].dims, 'int32ptr'],
+          [result.integerData as Uint8Array, 'boolptr', 'out'], [result.integerData.length, 'int32'],
+          [outputShape.length, 'int32'], [outputShape, 'int32ptr']);
+    } else {
+      throw new Error(`Unsupported binary op format. Probably unsupported data types.`);
+    }
+    return [result];
+  }
+}
diff --git a/lib/backends/wasm/ops/matmul.ts b/lib/backends/wasm/ops/matmul.ts
new file mode 100644
index 00000000..8caf6c61
--- /dev/null
+++ b/lib/backends/wasm/ops/matmul.ts
@@ -0,0 +1,47 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+import {MatMul} from '../../../ops/matmul';
+import {Tensor} from '../../../tensor';
+import {BroadcastUtil, MatMulUtil, ShapeUtil} from '../../../util';
+import {WasmBinding} from '../../../wasm-binding';
+import {WasmInferenceHandler} from '../inference-handler';
+
+export class WasmMatMul extends MatMul {
+  run(inferenceHandler: WasmInferenceHandler, inputs: Tensor[]): Tensor[] {
+    let dimsA: number[];
+    let dimsB: number[];
+    [dimsA, dimsB] = MatMulUtil.preprocessInputShapes(inputs[0].dims.slice(), inputs[1].dims.slice());
+    const outputShape = BroadcastUtil.calcShape(dimsA, dimsB, true);
+    if (!outputShape) {
+      // the inputs cannot broadcast or cannot multiply
+      throw new Error(`input dimensions do not match the requirement`);
+    }
+
+    const outputSize = ShapeUtil.size(outputShape);
+    const resultData = new Float32Array(outputSize);
+    WasmBinding.getInstance().ccall(
+        '_matmul_f32', [inputs[0].floatData, 'float32ptr'], [inputs[0].dims, 'int32ptr'],
+        [inputs[0].dims.length, 'int32'], [inputs[1].floatData, 'float32ptr'], [inputs[1].dims, 'int32ptr'],
+        [inputs[1].dims.length, 'int32'], [resultData, 'float32ptr', 'out'], [resultData.length, 'int32'],
+        [outputShape, 'int32ptr'], [outputShape.length, 'int32']);
+    MatMulUtil.postprocessOutputShape(outputShape as number[], inputs[0].dims.length, inputs[1].dims.length);
+    const result = new Tensor(outputShape, inputs[0].type);
+    result.floatData.set(resultData);
+    return [result];
+  }
+
+  // overriding the checkInputTypes() in the base class because the Wasm backend has special type limitations
+  checkInputTypes(inputs: Tensor[]): boolean {
+    // currently the Wasm backend only supports the 'float32' input type
+    if (inputs[0].type !== 'float32' || inputs[1].type !== 'float32') {
+      return false;
+    }
+
+    if (inputs[0].type !== inputs[1].type) {
+      return false;
+    }
+
+    return true;
+  }
+}
diff --git a/lib/backends/wasm/session-handler.ts b/lib/backends/wasm/session-handler.ts
index 5a907469..1a85c832 100644
--- a/lib/backends/wasm/session-handler.ts
+++ b/lib/backends/wasm/session-handler.ts
@@ -8,8 +8,10 @@ import {Session} from '../../session';
 import {resolve} from '../cpu/ops-resolve';
 import {WasmInferenceHandler} from './inference-handler';
 import {WasmBatchNormalization} from './ops/batch-normalization';
+import {WasmBinaryOp} from './ops/binary-op';
 import {WasmConv} from './ops/conv';
 import {WasmGemm} from './ops/gemm';
+import {WasmMatMul} from './ops/matmul';
 import {WasmAveragePool, WasmGlobalAveragePool, WasmGlobalMaxPool, WasmMaxPool} from './ops/pool';
 import {WasmSoftmax} from './ops/softmax';
 import {WasmSum} from './ops/sum';
@@ -32,12 +34,32 @@ export class WasmSessionHandler implements SessionHandler {
   private createOperator(node: Graph.Node, domain: string, version: number): Operator {
     // assume domain=ai.onnx, version=v7
     switch (node.opType) {
+      // Binary arithmetic ops
+      case 'Add':
+        return new WasmBinaryOp(['float32'], 'Add');
+      case 'Sub':
+        return new WasmBinaryOp(['float32'], 'Sub');
+      case 'Mul':
+        return new WasmBinaryOp(['float32'], 'Mul');
+      case 'Div':
+        // TODO: Handle division by zero
+        return new WasmBinaryOp(['float32'], 'Div');
+      // Binary logical ops
+      case 'Xor':
+        return new WasmBinaryOp(['bool'], 'Xor');
+      case 'Or':
+        return new WasmBinaryOp(['bool'], 'Or');
+      case 'And':
+        return new WasmBinaryOp(['bool'], 'And');
+      // Misc ops
       case 'Conv':
         return new WasmConv();
       case 'BatchNormalization':
         return new WasmBatchNormalization();
       case 'Gemm':
         return new WasmGemm();
+      case 'MatMul':
+        return new WasmMatMul();
       case 'Softmax':
         return new WasmSoftmax();
       case 'Sum':
@@ -50,6 +72,8 @@ export class WasmSessionHandler implements SessionHandler {
         return new WasmGlobalMaxPool();
       case 'GlobalAveragePool':
         return new WasmGlobalAveragePool();
+      case 'PRelu':
+        return new WasmBinaryOp(['float32'], 'PRelu');
       default:
         if (this.fallbackToCpuOps) {
           return resolve(node, domain, version);
diff --git a/lib/backends/webgl/ops/binary-op.ts b/lib/backends/webgl/ops/binary-op.ts
index 8a769dc4..a689edd2 100644
--- a/lib/backends/webgl/ops/binary-op.ts
+++ b/lib/backends/webgl/ops/binary-op.ts
@@ -1,7 +1,6 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT license.
 
-import {Attribute} from '../../../attribute';
 import {BinaryOp} from '../../../ops/binary-op';
 import {Tensor} from '../../../tensor';
 import {BroadcastUtil, ShapeUtil} from '../../../util';
@@ -14,11 +13,10 @@ import {WebGLOperatorHelper} from '../webgl-operator-utils';
 
 export class WebGLBinaryOp extends BinaryOp implements WebGLOperator {
   constructor(
-      protected typeConstraint: ReadonlyArray<Tensor.DataType>, protected glslFunc: GlslValueFunction,
-      protected outputType?: Tensor.DataType) {
-    super(typeConstraint);
+      typeConstraint: ReadonlyArray<Tensor.DataType>, protected glslFunc: GlslValueFunction, opType?: string,
+      resultType?: Tensor.DataType) {
+    super(typeConstraint, opType, resultType);
   }
-  initialize(attributes: Attribute): void {}
   run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
     return WebGLOperatorHelper.run(this, inferenceHandler, inputs);
   }
@@ -76,7 +74,7 @@ export class WebGLBinaryOp extends BinaryOp implements WebGLOperator {
     return {
       inputTextureDatas: inputTDs,
       outputTextureData: handler.createTextureDataFromLayout(
-          programInfo.outputLayout, this.outputType ? this.outputType : inputs[0].type),
+          programInfo.outputLayout, this.resultType ? this.resultType : inputs[0].type),
       uniformData: {}
     };
   }
diff --git a/lib/backends/webgl/ops/matmul.ts b/lib/backends/webgl/ops/matmul.ts
index 3eb5431b..60edc71a 100644
--- a/lib/backends/webgl/ops/matmul.ts
+++ b/lib/backends/webgl/ops/matmul.ts
@@ -15,15 +15,15 @@ export class WebGLMatMul extends MatMul implements WebGLOperator {
     return WebGLOperatorHelper.run(this, inferenceHandler, inputs);
   }
   createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
-    const aShape = inputs[0].dims.slice();
-    const bShape = inputs[1].dims.slice();
+    const aShape = inputs[0].dims;
+    const bShape = inputs[1].dims;
     const outputShape = BroadcastUtil.calcShape(aShape, bShape, true);
     if (!outputShape) {
       throw new Error('Can\'t use matmul on the given tensors');
     }
     const rank = outputShape.length;
-    const arank = inputs[0].dims.slice().length;
-    const brank = inputs[1].dims.slice().length;
+    const arank = aShape.length;
+    const brank = bShape.length;
     const sharedDim = aShape[aShape.length - 1];
     const shaderSource = `
       uniform sampler2D A;
diff --git a/lib/backends/webgl/session-handler.ts b/lib/backends/webgl/session-handler.ts
index f08dedf0..44081fcb 100644
--- a/lib/backends/webgl/session-handler.ts
+++ b/lib/backends/webgl/session-handler.ts
@@ -120,7 +120,7 @@ export class WebGLSessionHandler implements SessionHandler {
       case 'Dropout':
         return new WebGLDropout();
       case 'Equal':
-        return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslEqual(), 'bool');
+        return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslEqual(), undefined, 'bool');
       case 'Exp':
         return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslExp());
       case 'Floor':
@@ -134,7 +134,7 @@ export class WebGLSessionHandler implements SessionHandler {
       case 'GlobalMaxPool':
         return new WebGLGlobalMaxPool();
       case 'Greater':
-        return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslGreater(), 'bool');
+        return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslGreater(), undefined, 'bool');
       case 'Identity':
         return new unaryOps.WebGLUnaryOp(NUMBER_TYPES, unaryOps.glslIdentity());
       case 'ImageScaler':
@@ -142,7 +142,7 @@ export class WebGLSessionHandler implements SessionHandler {
       case 'LeakyRelu':
         return new WebGLLeakyRelu();
       case 'Less':
-        return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslLess(), 'bool');
+        return new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslLess(), undefined, 'bool');
       case 'Log':
         return new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslLog());
       case 'MatMul':
diff --git a/lib/ops/binary-op.ts b/lib/ops/binary-op.ts
index 445bda79..d24c2eab 100644
--- a/lib/ops/binary-op.ts
+++ b/lib/ops/binary-op.ts
@@ -12,7 +12,8 @@ export abstract class BinaryOp implements Operator {
       protected resultType?: Tensor.DataType) {}
 
   abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
-  abstract initialize(attributes: Attribute): void;
+
+  initialize(attributes: Attribute): void {}
 
   checkInputs(inputs: Tensor[]): boolean {
     if (!inputs || inputs.length !== 2) {
diff --git a/lib/ops/matmul.ts b/lib/ops/matmul.ts
index 61b33217..eae672f9 100644
--- a/lib/ops/matmul.ts
+++ b/lib/ops/matmul.ts
@@ -32,6 +32,10 @@ export abstract class MatMul implements Operator {
       return false;
     }
 
+    if (inputs[0].type !== inputs[1].type) {
+      return false;
+    }
+
     return true;
   }
 }
diff --git a/lib/ops/unary-op.ts b/lib/ops/unary-op.ts
index 4e6f9bf9..ab1c79da 100644
--- a/lib/ops/unary-op.ts
+++ b/lib/ops/unary-op.ts
@@ -10,7 +10,10 @@ export abstract class UnaryOp implements Operator {
   constructor(protected typeConstraint: ReadonlyArray<Tensor.DataType>, protected resultType?: Tensor.DataType) {}
 
   abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
-  abstract initialize(attributes: Attribute): void;
+
+  initialize(attributes: Attribute): void {
+    this.attributes = attributes;
+  }
 
   checkInputs(inputs: Tensor[]): boolean {
     if (!inputs || inputs.length !== 1) {
@@ -27,4 +30,6 @@ export abstract class UnaryOp implements Operator {
 
     return true;
   }
+
+  protected attributes: Attribute;
 }
diff --git a/lib/util.ts b/lib/util.ts
index 3e05798f..07dde8a4 100644
--- a/lib/util.ts
+++ b/lib/util.ts
@@ -25,6 +25,60 @@ export function checkInputsShape(inputs: Tensor[], ...expectedDimensions: number
   return true;
 }
 
+export class MatMulUtil {
+  /**
+   * Fix the input shapes for the MatMul operation if they need fixing
+   * @param dimsA The shape of tensor A. Should be an array of positive integers
+   * @param dimsB The shape of tensor B. Should be an array of positive integers
+   * @returns A tuple containing the preprocessed input shapes as required by the ONNX specification
+   */
+  static preprocessInputShapes(dimsA: number[], dimsB: number[]): [number[], number[]] {
+    // If the first argument is 1-D, it is promoted to a matrix by prepending
+    // a 1 to its dimensions. After matrix multiplication the prepended 1 is
+    // removed.
+    if (dimsA.length === 1) {
+      dimsA = [1, dimsA[0]];
+    }
+    // If the second argument is 1-D, it is promoted to a matrix by appending
+    // a 1 to its dimensions. After matrix multiplication the appended 1 is
+    // removed.
+    if (dimsB.length === 1) {
+      dimsB = [dimsB[0], 1];
+    }
+
+    return [dimsA, dimsB];
+  }
+
+  /**
+   * Fix the output shape computed for the MatMul operation if it needs fixing
+   * @param outputShape The computed output shape. Should be an array (of length at least 2) of positive integers.
+   * This will be mutated.
+   * @param aRank The rank of tensor A.
+   * @param bRank The rank of tensor B.
+   */
+  static postprocessOutputShape(outputShape: number[], aRank: number, bRank: number) {
+    // Remove prepended dimension if first input is 1d
+    if (aRank === 1) {
+      outputShape.splice(outputShape.length - 2, 1);
+    }
+    // Remove appended dimension if second input is 1d
+    if (bRank === 1) {
+      outputShape.pop();
+    }
+  }
+
+  /**
+   * Calculate the expected shape of a matrix multiplication
+   * @param a The shape of tensor A. Should be a tuple of 2 positive integers
+   * @param b The shape of tensor B. Should be a tuple of 2 positive integers
+   * @returns The expected shape of the result, or undefined if N/A
+   */
+  static calcMatMulShape(a: [number, number], b: [number, number]): [number, number]|undefined {
+    return (a[1] !== b[0]) ? undefined : [a[0], b[1]];
+  }
+}
+
 export class BroadcastUtil {
   /**
    * Calculate the expected shape when broadcasting 2 tensors
@@ -52,7 +106,7 @@ export class BroadcastUtil {
       return undefined;
     }
     const cShapeMatMul =
-        BroadcastUtil.calcMatMulShape([adims[arank - 2], adims[arank - 1]], [bdims[brank - 2], bdims[brank - 1]]);
+        MatMulUtil.calcMatMulShape([adims[arank - 2], adims[arank - 1]], [bdims[brank - 2], bdims[brank - 1]]);
     if (cShapeMatMul === undefined) {
       return undefined;
     }
@@ -72,16 +126,6 @@ export class BroadcastUtil {
     return cdims;
   }
 
-  /**
-   * Calculate the expected shape when matrix multiplication
-   * @param a The shape of tensor A. Should be a tuple of 2 positive integers
-   * @param b The shape of tensor B. Should be a tuple of 2 positive integers
-   * @returns The expected shape of the result, or undefined if N/A
-   */
-  static calcMatMulShape(a: [number, number], b: [number, number]): [number, number]|undefined {
-    return (a[1] !== b[0]) ? undefined : [a[0], b[1]];
-  }
-
   /**
    * Given the indices of a broadcasted tensor, calculate the original indices
    * @param broadcastedIndices The given indices of the broadcasted tensor.
diff --git a/lib/wasm-binding-core.ts b/lib/wasm-binding-core.ts
index c609b2b3..457538c5 100644
--- a/lib/wasm-binding-core.ts
+++ b/lib/wasm-binding-core.ts
@@ -25,8 +25,9 @@ export interface WasmCallArgumentTypeMap {
   int32: number;
   float32: number;
   float64: number;
-  int32ptr: ReadonlyArray<number>|Int32Array|null;
-  float32ptr: ReadonlyArray<number>|Float32Array|null;
+  boolptr: ReadonlyArray<boolean>|Uint8Array;
+  int32ptr: ReadonlyArray<number>|Uint32Array|Int32Array|null;
+  float32ptr: ReadonlyArray<number>|Int32Array|Uint32Array|Float32Array|null;
   float64ptr: ReadonlyArray<number>|Float64Array|null;
 }
 
@@ -173,6 +174,17 @@ export class WasmBinding {
       case 'float64':
         len = 8;
         break;
+      case 'boolptr':
+        if (!paramData) {
+          // deal with nullptr
+          offset.push(0);
+          continue;
+        } else if (Array.isArray(paramData) || ArrayBuffer.isView(paramData)) {
+          len = 4 * Math.ceil(paramData.length / 4);
+        } else {
+          throw new Error(`boolptr requires a boolean array or a Uint8Array`);
+        }
+        break;
       case 'int32ptr':
       case 'float32ptr':
         if (!paramData) {
@@ -233,6 +245,11 @@
       case 'float32':
         heapF32[offset32] = paramData as number;
         break;
+      case 'boolptr':
+        const boolArray = paramData as WasmCallArgumentTypeMap['boolptr'];
+        // This works for both Uint8Array and ReadonlyArray<boolean>
+        heapU8.subarray(offset8, offset8 + boolArray.length).set(paramData as Uint8Array);
+        break;
       case 'int32ptr':
         const int32Array = (paramData as WasmCallArgumentTypeMap['int32ptr'])!;
         heap32.subarray(offset32, offset32 + int32Array.length).set(int32Array);
@@ -250,6 +267,7 @@
   // retrieve data parameters (in/inout) from emscripten heap after ccall()
   static ccallDeserialize(buffer: Uint8Array, offset: number[], params: WasmCallArgument[]) {
     const heapF32 = new Float32Array(buffer.buffer, buffer.byteOffset);
+    const heapU8 = new Uint8Array(buffer.buffer, buffer.byteOffset);
 
     for (let i = 0; i < params.length; i++) {
       const param = params[i];
@@ -270,6 +288,10 @@
           const float32Array = (paramData as Float32Array);
           float32Array.set(heapF32.subarray(offset32, offset32 + float32Array.length));
           break;
+        case 'boolptr':
+          const boolArray = (paramData as Uint8Array);
+          boolArray.set(heapU8.subarray(offset8, offset8 + boolArray.length));
+          break;
         default:
           throw new Error(`not supported parameter type: ${paramType}`);
       }
diff --git a/src/wasm-build-config.json b/src/wasm-build-config.json
index 7f2c964b..fbc78878 100644
--- a/src/wasm-build-config.json
+++ b/src/wasm-build-config.json
@@ -1,14 +1,23 @@
 {
   "src": [
-    "./wasm-ops/*.cpp"
+    "./wasm-ops/**/*.cpp"
   ],
   "exported_functions": [
     "_malloc",
     "_free",
+    "_add_f32",
+    "_sub_f32",
+    "_mul_f32",
+    "_div_f32",
+    "_prelu_f32",
+    "_xor_u8",
+    "_or_u8",
+    "_and_u8",
     "_conv_f32",
     "_average_pool_f32",
     "_max_pool_f32",
     "_gemm_f32",
+    "_matmul_f32",
     "_batch_normalization_f32",
     "_sum_f32",
     "_softmax_f32"
diff --git a/src/wasm-ops/binary-op.cpp b/src/wasm-ops/binary-op.cpp
new file mode 100644
index 00000000..e9d588fb
--- /dev/null
+++ b/src/wasm-ops/binary-op.cpp
@@ -0,0 +1,63 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "binary-op.h"
+#include "common.h"
+
+// Wasm interop methods
+void add_f32(void *data) {
+  uint32_t *dataIndex = static_cast<uint32_t *>(data);
+  const float *input_1 = PARAM_FLOAT_PTR(data, dataIndex[1]);
+  const float *input_2 = PARAM_FLOAT_PTR(data, dataIndex[4]);
+  float *output = PARAM_FLOAT_PTR(data, dataIndex[7]);
+  binary_imp<float, Add>(data, input_1, input_2, output);
+}
+void sub_f32(void *data) {
+  uint32_t *dataIndex = static_cast<uint32_t *>(data);
+  const float *input_1 = PARAM_FLOAT_PTR(data, dataIndex[1]);
+  const float *input_2 = PARAM_FLOAT_PTR(data, dataIndex[4]);
+  float *output = PARAM_FLOAT_PTR(data, dataIndex[7]);
+  binary_imp<float, Sub>(data, input_1, input_2, output);
+}
+void mul_f32(void *data) {
+  uint32_t *dataIndex = static_cast<uint32_t *>(data);
+  const float *input_1 = PARAM_FLOAT_PTR(data, dataIndex[1]);
+  const float *input_2 = PARAM_FLOAT_PTR(data, dataIndex[4]);
+  float *output = PARAM_FLOAT_PTR(data, dataIndex[7]);
+  binary_imp<float, Mul>(data, input_1, input_2, output);
+}
+void div_f32(void *data) {
+  uint32_t *dataIndex = static_cast<uint32_t *>(data);
+  const float *input_1 = PARAM_FLOAT_PTR(data, dataIndex[1]);
+  const float *input_2 = PARAM_FLOAT_PTR(data, dataIndex[4]);
+  float *output = PARAM_FLOAT_PTR(data, dataIndex[7]);
+  binary_imp<float, Div>(data, input_1, input_2, output);
+}
+void prelu_f32(void *data) {
+  uint32_t *dataIndex = static_cast<uint32_t *>(data);
+  const float *input_1 = PARAM_FLOAT_PTR(data, dataIndex[1]);
+  const float *input_2 = PARAM_FLOAT_PTR(data, dataIndex[4]);
+  float *output = PARAM_FLOAT_PTR(data, dataIndex[7]);
+  binary_imp<float, PRelu>(data, input_1, input_2, output);
+}
+void xor_u8(void *data) {
+  uint32_t *dataIndex = static_cast<uint32_t *>(data);
+  const uint8_t *input_1 = PARAM_BOOL_PTR(data, dataIndex[1]);
+  const uint8_t *input_2 = PARAM_BOOL_PTR(data, dataIndex[4]);
+  uint8_t *output = PARAM_BOOL_PTR(data, dataIndex[7]);
+  binary_imp<uint8_t, Xor>(data, input_1, input_2, output);
+}
+void or_u8(void *data) {
+  uint32_t *dataIndex = static_cast<uint32_t *>(data);
+  const uint8_t *input_1 = PARAM_BOOL_PTR(data, dataIndex[1]);
+  const uint8_t *input_2 = PARAM_BOOL_PTR(data, dataIndex[4]);
+  uint8_t *output = PARAM_BOOL_PTR(data, dataIndex[7]);
+  binary_imp<uint8_t, Or>(data, input_1, input_2, output);
+}
+void and_u8(void *data) {
+  uint32_t *dataIndex = static_cast<uint32_t *>(data);
+  const uint8_t *input_1 = PARAM_BOOL_PTR(data, dataIndex[1]);
+  const uint8_t *input_2 = PARAM_BOOL_PTR(data, dataIndex[4]);
+  uint8_t *output = PARAM_BOOL_PTR(data, dataIndex[7]);
+  binary_imp<uint8_t, And>(data, input_1, input_2, output);
+}
diff --git a/src/wasm-ops/binary-op.h b/src/wasm-ops/binary-op.h
new file mode 100644
index 00000000..2f38b61b
--- /dev/null
+++ b/src/wasm-ops/binary-op.h
@@ -0,0 +1,130 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include "common.h"
+#include "utils/broadcast_utils.h"
+#include "utils/shape_utils.h"
+#include <stdint.h>
+#include <vector>
+
+extern "C" {
+// Arithmetic ops
+void add_f32(void *);
+void sub_f32(void *);
+void mul_f32(void *);
+void div_f32(void *);
+void prelu_f32(void *);
+
+// Logical ops
+void xor_u8(void *);
+void or_u8(void *);
+void and_u8(void *);
+}
+
+// Binary operator (with broadcasting)
+template <typename T, typename BinaryOp>
+void binary_imp(void *data, const T *input_1, const T *input_2, T *output) {
+  uint32_t *dataIndex = static_cast<uint32_t *>(data);
+
+  // first input related
+  const int32_t rank_1 = PARAM_INT32(data, dataIndex[2]);
+  const int32_t *dims_1 = PARAM_INT32_PTR(data, dataIndex[3]);
+  std::vector<int32_t> dims1_vector;
+  if (rank_1 > 0) {
+    dims1_vector.resize(rank_1);
+    for (int32_t i = 0; i < rank_1; ++i) {
+      dims1_vector[i] = dims_1[i];
+    }
+  }
+
+  // second input related
+  const int32_t rank_2 = PARAM_INT32(data, dataIndex[5]);
+  const int32_t *dims_2 = PARAM_INT32_PTR(data, dataIndex[6]);
+  std::vector<int32_t> dims2_vector;
+  if (rank_2 > 0) {
+    dims2_vector.resize(rank_2);
+    for (int32_t i = 0; i < rank_2; ++i) {
+      dims2_vector[i] = dims_2[i];
+    }
+  }
+
+  // output related
+  const int32_t output_length = PARAM_INT32(data, dataIndex[8]);
+  const int32_t output_rank = PARAM_INT32(data, dataIndex[9]);
+  const int32_t *output_dims = PARAM_INT32_PTR(data, dataIndex[10]);
+  std::vector<int32_t> output_dims_vector;
+  if (output_rank != 0) {
+    output_dims_vector.resize(output_rank);
+    for (int32_t i = 0; i < output_rank; ++i) {
+      output_dims_vector[i] = output_dims[i];
+    }
+  }
+
+  // compute strides and some preprocessing
+  const std::vector<int32_t> strides_1 =
+      ShapeUtils::compute_strides(dims1_vector);
+  const std::vector<int32_t> strides_2 =
+      ShapeUtils::compute_strides(dims2_vector);
+  const std::vector<int32_t> output_strides =
+      ShapeUtils::compute_strides(output_dims_vector);
+  std::vector<int32_t> indices_1(rank_1);
+  std::vector<int32_t> indices_2(rank_2);
+  std::vector<int32_t> broadcasted_indices(output_strides.size());
+
+  // core functionality (with broadcasting)
+  for (size_t i = 0; i < output_length; ++i) {
+    ShapeUtils::offset_to_indices(output_strides, i, broadcasted_indices);
+    BroadcastUtils::broadcasted_to_original_indices(broadcasted_indices,
+                                                    dims1_vector, indices_1);
+    auto offset1 = ShapeUtils::indices_to_offset(strides_1, indices_1);
+    BroadcastUtils::broadcasted_to_original_indices(broadcasted_indices,
+                                                    dims2_vector, indices_2);
+    auto offset2 = ShapeUtils::indices_to_offset(strides_2, indices_2);
+    output[i] = BinaryOp::calc(input_1[offset1], input_2[offset2]);
+  }
+}
+
+// Core op classes
+class Add {
+public:
+  template <typename T> static T calc(const T &a, const T &b) { return a + b; }
+};
+
+class Sub {
+public:
+  template <typename T> static T calc(const T &a, const T &b) { return a - b; }
+};
+
+class Mul {
+public:
+  template <typename T> static T calc(const T &a, const T &b) { return a * b; }
+};
+
+class Div {
+public:
+  template <typename T> static T calc(const T &a, const T &b) { return a / b; }
+};
+
+class PRelu {
+public:
+  template <typename T> static T calc(const T &a, const T &b) {
+    return a >= 0 ? a : a * b;
+  }
+};
+
+class Xor {
+public:
+  template <typename T> static T calc(const T &a, const T &b) { return a ^ b; }
+};
+
+class Or {
+public:
+  template <typename T> static T calc(const T &a, const T &b) { return a || b; }
+};
+
+class And {
+public:
+  template <typename T> static T calc(const T &a, const T &b) { return a && b; }
+};
diff --git a/src/wasm-ops/common.h b/src/wasm-ops/common.h
index 2217ade4..0bd87ee7 100644
--- a/src/wasm-ops/common.h
+++ b/src/wasm-ops/common.h
@@ -18,6 +18,7 @@ static_assert(sizeof(int) == sizeof(int32_t),
 #define PARAM_BOOL(data, offset) (!!PARAM_VALUE(data, offset, uint8_t))
 #define PARAM_INT32(data, offset) PARAM_VALUE(data, offset, int32_t)
 #define PARAM_FLOAT(data, offset) PARAM_VALUE(data, offset, float)
+#define PARAM_BOOL_PTR(data, offset) PARAM_PTR(data, offset, uint8_t)
 #define PARAM_INT32_PTR(data, offset) PARAM_PTR(data, offset, int32_t)
 #define PARAM_FLOAT_PTR(data, offset) PARAM_PTR(data, offset, float)
 
diff --git a/src/wasm-ops/matmul.cpp b/src/wasm-ops/matmul.cpp
new file mode 100644
index 00000000..caa1f97a
--- /dev/null
+++ b/src/wasm-ops/matmul.cpp
@@ -0,0 +1,129 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#include "matmul.h"
+#include "common.h"
+#include "utils/broadcast_utils.h"
+#include "utils/shape_utils.h"
+#include <math.h>
+
+// Wasm interop method
+void matmul_f32(void *data) {
+  uint32_t *dataIndex = static_cast<uint32_t *>(data);
+  uint32_t const argc = dataIndex[0];
+  const float *input_1 = PARAM_FLOAT_PTR(data, dataIndex[1]);
+  const int32_t *dims_1 = PARAM_INT32_PTR(data, dataIndex[2]);
+  const int32_t rank_1 = PARAM_INT32(data, dataIndex[3]);
+  const float *input_2 = PARAM_FLOAT_PTR(data, dataIndex[4]);
+  const int32_t *dims_2 = PARAM_INT32_PTR(data, dataIndex[5]);
+  const int32_t rank_2 = PARAM_INT32(data, dataIndex[6]);
+  float *output = PARAM_FLOAT_PTR(data, dataIndex[7]);
+  const int32_t output_length = PARAM_INT32(data, dataIndex[8]);
+  const int32_t *output_dims = PARAM_INT32_PTR(data, dataIndex[9]);
+  const int32_t output_rank = PARAM_INT32(data, dataIndex[10]);
+  matmul_f32_imp(input_1, dims_1, rank_1, input_2, dims_2, rank_2, output,
+                 output_length, output_dims, output_rank);
+}
+
+// Core operator implementation
+void matmul_f32_imp(const float *input_1, const int32_t *dims_1,
+                    const int32_t rank_1, const float *input_2,
+                    const int32_t *dims_2, const int32_t rank_2, float *output,
+                    const int32_t output_length, const int32_t *output_dims,
+                    const int32_t output_rank) {
+  int32_t M = dims_1[rank_1 - 2];
+  int32_t K = dims_1[rank_1 - 1];
+  int32_t N = dims_2[rank_2 - 1];
+
+  // 2D matrices only
+  if (output_rank == 2) {
+    matmul2D_f32(input_1, input_2, output, M, K, N);
+    return;
+  }
+
+  // multi-D matrices
+  else {
+    const int32_t num_matrices = (output_length) / (M * N);
+
+    const float *input_1_traverse;
+    const float *input_2_traverse;
+
+    std::vector<int32_t> dims_1_vector(rank_1);
+    for (int32_t r = 0; r < rank_1; ++r) {
+      dims_1_vector[r] = dims_1[r];
+    }
+    const std::vector<int32_t> strides_1 =
+        ShapeUtils::compute_strides(dims_1_vector);
+
+    std::vector<int32_t> dims_2_vector(rank_2);
+    for (int32_t r = 0; r < rank_2; ++r) {
+      dims_2_vector[r] = dims_2[r];
+    }
+    const std::vector<int32_t> strides_2 =
+        ShapeUtils::compute_strides(dims_2_vector);
+
+    std::vector<int32_t> broadcasted_indices =
+        std::vector<int32_t>(output_rank);
+    broadcasted_indices[output_rank - 1] = 0;
+    broadcasted_indices[output_rank - 2] = 0;
+
+    std::vector<int32_t> original_indices_1(rank_1);
+    int32_t original_offset_1;
+    std::vector<int32_t> original_indices_2(rank_2);
+    int32_t original_offset_2;
+
+    for (int32_t i = 0; i < num_matrices; i++) {
+      // Compute broadcasted_indices for this offset
+      int32_t offset_remainder = i;
+      for (int32_t j = output_rank - 3; j >= 0; j--) {
+        broadcasted_indices[j] = offset_remainder % output_dims[j];
+        offset_remainder = floor(offset_remainder / output_dims[j]);
+      }
+
+      // This matrix is 2D, so there is no need to find a start offset
+      if (rank_1 == 2) {
+        input_1_traverse = input_1;
+      }
+      // This matrix is not 2D, so we need to find the appropriate start offset
+      else {
+        BroadcastUtils::broadcasted_to_original_indices(
+            broadcasted_indices, dims_1_vector, original_indices_1);
+        original_offset_1 =
+            ShapeUtils::indices_to_offset(strides_1, original_indices_1);
+        input_1_traverse = input_1 + original_offset_1;
+      }
+
+      // This matrix is 2D, so there is no need to find a start offset
+      if (rank_2 == 2) {
+        input_2_traverse = input_2;
+      }
+      // This matrix is not 2D, so we need to find the appropriate start offset
+      else {
+        BroadcastUtils::broadcasted_to_original_indices(
+            broadcasted_indices, dims_2_vector, original_indices_2);
+        original_offset_2 =
+            ShapeUtils::indices_to_offset(strides_2, original_indices_2);
+        input_2_traverse = input_2 + original_offset_2;
+      }
+
+      // process this 2D component alone
+      matmul2D_f32(input_1_traverse, input_2_traverse, output + (i * M * N), M,
+                   K, N);
+    }
+  }
+}
+
+// Core functionality implementation
+void matmul2D_f32(const float *input_1, const float *input_2, float *output,
+                  const int32_t M, const int32_t K, const int32_t N) {
+  for (int32_t row = 0; row < M; ++row) {
+    for (int32_t col = 0; col < N; ++col) {
+      float sum = 0;
+      for (int32_t traverse = 0; traverse < K; ++traverse) {
+        sum += input_1[row * K + traverse] * input_2[traverse * N + col];
+      }
+      output[row * N + col] = sum;
+    }
+  }
+}
diff --git a/src/wasm-ops/matmul.h b/src/wasm-ops/matmul.h
new file mode 100644
index 00000000..fb5db08d
--- /dev/null
+++ b/src/wasm-ops/matmul.h
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+
+#pragma once
+
+#include <stdint.h>
+
+extern "C" {
+void matmul_f32(void *);
+void matmul_f32_imp(const float *, const int32_t *, const int32_t,
+                    const float *, const int32_t *, const int32_t, float *,
+                    const int32_t, const int32_t *, const int32_t);
+void matmul2D_f32(const float *, const float *, float *, const int32_t,
+                  const int32_t, const int32_t);
+}
diff --git a/src/wasm-ops/softmax.cpp b/src/wasm-ops/softmax.cpp
index 23e66e59..213be29b 100644
--- a/src/wasm-ops/softmax.cpp
+++ b/src/wasm-ops/softmax.cpp
@@ -15,7 +15,7 @@ void softmax_f32(void *data) {
 }
 
 // Core operator implementation
-void softmax_f32_imp(float *X, float *Y, int N, int D) {
+void softmax_f32_imp(float *X, float *Y, int32_t N, int32_t D) {
   for (size_t i = 0; i < N; i++) {
     // find row offset
     int offset = i * D;
diff --git a/src/wasm-ops/utils/broadcast_utils.cpp b/src/wasm-ops/utils/broadcast_utils.cpp
new file mode 100644
index 00000000..982ae8e2
--- /dev/null
+++ b/src/wasm-ops/utils/broadcast_utils.cpp
@@ -0,0 +1,30 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license.
+ +#include "broadcast_utils.h" + +std::vector BroadcastUtils::broadcasted_to_original_indices( + const std::vector &broadcasted_indices, + const std::vector &dims) { + const auto rank = dims.size(); + if (rank == 0) { + return std::vector(); + } + std::vector original_indices(rank); + BroadcastUtils::broadcasted_to_original_indices(broadcasted_indices, dims, + original_indices); + return original_indices; +} + +void BroadcastUtils::broadcasted_to_original_indices( + const std::vector &broadcasted_indices, + const std::vector &dims, std::vector &original_indices) { + const auto rank = dims.size(); + if (rank == 0) { + return; + } + auto offset = broadcasted_indices.size() - dims.size(); + for (size_t i = 0; i < rank; ++i) { + original_indices[i] = broadcasted_indices[offset + i] % dims[i]; + } +} diff --git a/src/wasm-ops/utils/broadcast_utils.h b/src/wasm-ops/utils/broadcast_utils.h new file mode 100644 index 00000000..b219be8a --- /dev/null +++ b/src/wasm-ops/utils/broadcast_utils.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include + +namespace BroadcastUtils { +std::vector +broadcasted_to_original_indices(const std::vector &broadcasted_indices, + const std::vector &dims); + +// Fills in values in the original_indices vector. Assumes it is of the required +// size. +void broadcasted_to_original_indices( + const std::vector &broadcasted_indices, + const std::vector &dims, std::vector &original_indices); +}; // namespace BroadcastUtils diff --git a/src/wasm-ops/utils/shape_utils.cpp b/src/wasm-ops/utils/shape_utils.cpp new file mode 100644 index 00000000..bd068b83 --- /dev/null +++ b/src/wasm-ops/utils/shape_utils.cpp @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +#include "shape_utils.h" +#include + +size_t ShapeUtils::size_from_dims(const std::vector &dims) { + auto rank = dims.size(); + if (rank == 0) { + return 1; + } + if (rank == 1) { + return dims[0]; + } + size_t size = 1; + for (auto &e : dims) { + size *= e; + } + return size; +} + +std::vector +ShapeUtils::compute_strides(const std::vector &dims) { + auto rank = dims.size(); + if (rank == 0 || rank == 1) { + std::vector strides(1, 1); + return strides; + } + std::vector strides(rank); + ShapeUtils::compute_strides(dims, strides); + return strides; +} + +void ShapeUtils::compute_strides(const std::vector &dims, + std::vector &strides) { + auto rank = dims.size(); + if (rank == 0 || rank == 1) { + strides[0] = 1; + return; + } + strides[rank - 1] = 1; + strides[rank - 2] = dims[rank - 1]; + for (int32_t i = rank - 3; i >= 0; --i) { + strides[i] = strides[i + 1] * dims[i + 1]; + } +} + +size_t ShapeUtils::indices_to_offset(const std::vector &strides, + const std::vector &indices) { + size_t offset = 0; + for (size_t i = 0; i < indices.size(); ++i) { + offset += strides[i] * indices[i]; + } + return offset; +} + +std::vector +ShapeUtils::offset_to_indices(const std::vector &strides, + size_t offset) { + auto rank = strides.size(); + if (rank == 0) { + return std::vector(); + } + if (rank == 1) { + return std::vector(1, offset * strides[0]); + } + std::vector indices(rank); + ShapeUtils::offset_to_indices(strides, offset, indices); + return indices; +} + +void ShapeUtils::offset_to_indices(const std::vector &strides, + size_t offset, + std::vector &indices) { + auto rank = strides.size(); + if (rank == 0) { + return; + } + if (rank == 1) { + indices[0] = offset * strides[0]; + return; + } + for (size_t i = 0; i < indices.size() - 1; ++i) { + indices[i] = floor(offset / strides[i]); + offset -= indices[i] * strides[i]; + } + indices[indices.size() - 1] = offset; +} diff --git a/src/wasm-ops/utils/shape_utils.h b/src/wasm-ops/utils/shape_utils.h new file mode 100644 index 00000000..2d138136 --- /dev/null +++ b/src/wasm-ops/utils/shape_utils.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once + +#include + +namespace ShapeUtils { +size_t size_from_dims(const std::vector &dims); +std::vector compute_strides(const std::vector &dims); +// Fills in values in the strides vector. Assumes it is of the required size. +void compute_strides(const std::vector &dims, + std::vector &strides); +size_t indices_to_offset(const std::vector &strides, + const std::vector &indices); +std::vector offset_to_indices(const std::vector &strides, + size_t offset); +// Fills in values in the indices vector. Assumes it is of the required size. +void offset_to_indices(const std::vector &strides, size_t offset, + std::vector &indices); +}; // namespace ShapeUtils diff --git a/test/unittest-whitelist.jsonc b/test/unittest-whitelist.jsonc index 1344198b..a61f45c7 100644 --- a/test/unittest-whitelist.jsonc +++ b/test/unittest-whitelist.jsonc @@ -478,6 +478,40 @@ // Check in node tests that have native Wasm implementations // (i.e.) 
       // Use the 'cpu' level of node tests to test those implementations
+      "test_add_bcast",
+      "test_add",
+      "test_sub_bcast",
+      "test_sub_example",
+      "test_sub",
+      "test_mul_bcast",
+      "test_mul_example",
+      "test_mul",
+      "test_div_bcast",
+      "test_div_example",
+      "test_div",
+      "test_xor_bcast3v1d",
+      "test_xor_bcast3v2d",
+      "test_xor_bcast4v2d",
+      "test_xor_bcast4v3d",
+      "test_xor_bcast4v4d",
+      "test_xor2d",
+      "test_xor3d",
+      "test_xor4d",
+      "test_or_bcast3v1d",
+      "test_or_bcast3v2d",
+      "test_or_bcast4v2d",
+      "test_or_bcast4v3d",
+      "test_or_bcast4v4d",
+      "test_and_bcast3v1d",
+      "test_and_bcast3v2d",
+      "test_and_bcast4v2d",
+      "test_and_bcast4v3d",
+      "test_and_bcast4v4d",
+      "test_and2d",
+      "test_and3d",
+      "test_and4d",
+      "test_prelu_broadcast",
+      "test_prelu_example",
       "test_basic_conv_with_padding",
       "test_basic_conv_without_padding",
       "test_batchnorm_epsilon",
@@ -487,6 +521,9 @@
       "test_conv_with_strides_padding",
       "test_gemm_nobroadcast",
       "test_gemm_broadcast",
+      "test_matmul_2d",
+      "test_matmul_3d",
+      "test_matmul_4d",
      "test_softmax_axis_0",
      "test_softmax_axis_1",
      "test_softmax_axis_2",
@@ -526,7 +563,15 @@
      // (i.e.) not tests that rely on the fallback cpu implementations
      // Use the 'cpu' level of node tests to test those implementations
      "conv.jsonc",
-     "softmax.jsonc"
+     "softmax.jsonc",
+     "add.jsonc",
+     "sub.jsonc",
+     "mul.jsonc",
+     "div.jsonc",
+     "and.jsonc",
+     "or.jsonc",
+     "xor.jsonc",
+     "matmul.jsonc"
     ]
   }
 }
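
Reviewer note: the shape handling that this diff centralizes in `MatMulUtil` (and that both the CPU and Wasm MatMul paths now share) follows the ONNX/NumPy `MatMul` rules: a 1-D input is temporarily promoted to 2-D, and the padded dimension is squeezed back out of the output. A minimal standalone TypeScript sketch of the two helpers, reimplemented here purely for illustration (only the names and the splice/pop logic are taken from the diff; the harness itself is hypothetical):

```ts
// Illustrative re-implementation of the MatMulUtil shape helpers added in lib/util.ts.
function preprocessInputShapes(dimsA: number[], dimsB: number[]): [number[], number[]] {
  if (dimsA.length === 1) dimsA = [1, dimsA[0]];  // promote 1-D A to [1, K]
  if (dimsB.length === 1) dimsB = [dimsB[0], 1];  // promote 1-D B to [K, 1]
  return [dimsA, dimsB];
}

function postprocessOutputShape(outputShape: number[], aRank: number, bRank: number): void {
  if (aRank === 1) outputShape.splice(outputShape.length - 2, 1);  // drop the prepended 1
  if (bRank === 1) outputShape.pop();                              // drop the appended 1
}

// Example: [3] x [3, 4] is computed as [1, 3] x [3, 4] -> [1, 4], then squeezed to [4].
const [a, b] = preprocessInputShapes([3], [3, 4]);
const out = [a[0], b[1]];  // result shape of the promoted 2-D matmul
postprocessOutputShape(out, 1, 2);
console.log(out);  // [ 4 ]
```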