From 2d4ddf941f8cbadec115dc98409d993755e41fb3 Mon Sep 17 00:00:00 2001 From: xueyun zhu Date: Thu, 6 May 2021 16:00:15 -0700 Subject: [PATCH 1/4] adding conv fuse logic --- .../onnxjs/backends/webgl/ops/conv-pack.ts | 8 +- js/web/lib/onnxjs/backends/webgl/ops/conv.ts | 41 ++++--- .../onnxjs/backends/webgl/ops/fuse_utils.ts | 23 ++++ .../onnxjs/backends/webgl/ops/matmul-pack.ts | 9 +- js/web/lib/onnxjs/graph.ts | 111 +++++++++++------- js/web/lib/onnxjs/ops/conv.ts | 4 +- js/web/lib/onnxjs/ops/matmul.ts | 7 +- 7 files changed, 138 insertions(+), 65 deletions(-) create mode 100644 js/web/lib/onnxjs/backends/webgl/ops/fuse_utils.ts diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts index 2f9a353935938..1976494bddfe5 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Licensed under the MIT license. +import {Attribute} from '../../../attribute'; import {Logger} from '../../../instrument'; import {Conv} from '../../../ops/conv'; import {Tensor} from '../../../tensor'; @@ -36,6 +37,11 @@ export class WebGLConvPacked extends Conv { const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides); const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides); const matmul = new WebGLMatMulPacked(); + if (!!this.activation) { + const attributes = new Attribute(undefined); + attributes.set('__internal_activation', 'string', (this.activation)); + matmul.initialize(attributes); + } const reshape = new WebGLReshapePacked(); // shape for kernel reshape const shape = diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv.ts index 0fa0d61a34414..5600b7c98e856 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv.ts @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Licensed under the MIT license. import {Attribute} from '../../../attribute'; import {Logger} from '../../../instrument'; @@ -10,7 +10,9 @@ import {getGlsl} from '../glsl-source'; import {WebGLInferenceHandler} from '../inference-handler'; import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../types'; import {WebGLContext} from '../webgl-context'; + import {WebGLConvPacked} from './conv-pack'; +import {getActicationSnippet} from './fuse_utils'; export class WebGLConv extends Conv { unpackedGroupedConvImpl: WebGLUnpackedGroupedConv; @@ -66,7 +68,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo { const hasBias = inputs.length > 2; - const processBias = hasBias ? 'dotProd += getBias(output_channel);' : ''; + const processBias = hasBias ? `value += getBias(output_channel);` : ``; const xShape = inputs[0].dims.slice(); const wShape = inputs[1].dims.slice(); const outputChannelsPerGroup = wShape[0] / this.group; @@ -85,10 +87,12 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { const outputShape = WebGLConv.calcOutputShape(xShape, wShape, this.dilations, this.pads, this.strides); const glsl = getGlsl(handler.session.backend.glContext.version); + const {activationFunction, applyActivation} = getActicationSnippet(this.activation); + const shaderSource = ` const ivec2 strides = ivec2(${this.strides[0]}, ${this.strides[1]}); const ivec2 pads = ivec2(${this.pads[0]}, ${this.pads[1]}); - + ${activationFunction} void main() { ivec4 coords = getOutputCoords(); int batch = coords.x; @@ -96,7 +100,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { ivec2 xRCCorner = coords.zw * strides - pads; int group_id = output_channel / ${outputChannelsPerGroup}; - float dotProd = 0.0; + float value = 0.0; for (int wInChannel = 0; wInChannel < ${wShape[1]}; wInChannel++) { int input_channel = group_id * ${wShape[1]} + wInChannel; for (int wHeight = 0; wHeight < ${wShape[2]}; wHeight++) { @@ -114,12 +118,13 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { float xVal = getX(batch, input_channel, xWidth, xHeight); float wVal = getW(output_channel, wInChannel, wWidth, wHeight); - dotProd += xVal*wVal; + value += xVal*wVal; } } } ${processBias} - ${glsl.output} = vec4(dotProd, .0, .0, .0); + ${applyActivation} + ${glsl.output} = vec4(value, .0, .0, .0); } `; return { @@ -215,7 +220,6 @@ export class WebGLUnpackedConv extends Conv { let blend = false; for (let k = 0; k < sharedDim; k += sharedDimReadSize) { Logger.verbose('MatMul2D', `k = ${k}, sharedDim: ${sharedDim}, readSize = ${sharedDimReadSize}`); - if (k === sharedDimReadSize) { blend = true; gl.enable(gl.BLEND); @@ -248,6 +252,7 @@ export class WebGLUnpackedConv extends Conv { const im2colDims = WebGLUnpackedConv.calcIm2ColDims(xshape, kshape, outputShape, 4); const outputLayout = inferenceHandler.createTextureLayoutFromShape( im2colDims, 4, [im2colDims[0], im2colDims[1], im2colDims[2], im2colDims[3] * 4], {breakAxis: 3}); + const shaderSource = ` const int XC = ${xshape[1]}; const int XH = ${xshape[2]}; @@ -263,13 +268,12 @@ export class WebGLUnpackedConv extends Conv { const int KHKW = KH*KW; const int XCKHKW = XC * KHKW; const int outputChannels = 4; - vec4 process(int indices[${rank}]) { int b = indices[0]; // batch size int oh = indices[1] * strideH - padH; //output height int ow = indices[2] * strideW - padW; //output width int p = indices[3] * outputChannels; //patch - vec4 v = vec4(0.0); + vec4 value = vec4(0.0); for(int i=0; i < outputChannels; ++i) { if(p < XCKHKW) { int patchC = p / KHKW; @@ -286,12 +290,12 @@ export class WebGLUnpackedConv extends Conv { xh2 < XH && xw2 >= 0 && xw2 < XW) { - v[i] = _X(x); + value[i] = _X(x); } } ++p; } - return v; + return value; } `; return { @@ -321,7 +325,7 @@ export class WebGLUnpackedConv extends Conv { const outputLayout = inferenceHandler.createTextureLayoutFromShape(outputShape); const initValue = (inputs.length < 3) ? '0.0' : '_B(b)'; const sharedDim = im2colLayout.shape[3]; - const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported; + const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported && !this.activation; const sharedDimReadSize = blendEnabled && inferenceHandler.session.backend.matmulMaxBatchSize ? this.calcSharedDimReadSize(inferenceHandler.session.backend.matmulMaxBatchSize, sharedDim) : sharedDim; @@ -329,8 +333,12 @@ export class WebGLUnpackedConv extends Conv { if (inputs.length === 3) { samplers.push('B'); } + + const {activationFunction, applyActivation} = getActicationSnippet(this.activation); + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); const shaderSource = ` + ${activationFunction} float process(int indices[${rank}]) { int b[1]; b[0] = indices[1]; @@ -341,15 +349,16 @@ export class WebGLUnpackedConv extends Conv { int im2colOffset = im2col[0] * ${im2colLayout.strides[0]} + im2col[1] * ${ im2colLayout.strides[1]} + im2col[2] * ${im2colLayout.strides[2]} + sharedDimOffset; int kernelOffset = indices[1] * ${kLayout.strides[0]} + sharedDimOffset; - float sum = sharedDimOffset == 0 ? ${initValue} : 0.0; + float value = sharedDimOffset == 0 ? ${initValue} : 0.0; for (int i = 0; i < ${sharedDimReadSize}; ++i) { vec2 im2colCoords = offsetToCoords(im2colOffset, ${im2colLayout.width}, ${im2colLayout.height}); vec2 kernelCoords = offsetToCoords(kernelOffset, ${kLayout.width}, ${kLayout.height}); - sum += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords)); + value += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords)); ++im2colOffset; ++kernelOffset; } - return sum; + ${applyActivation} + return value; }`; return { inputLayouts: inputs.length === 3 ? [im2colLayout, kLayout, bLayout!] : [im2colLayout, kLayout], @@ -357,7 +366,7 @@ export class WebGLUnpackedConv extends Conv { shaderSource, samplers, variables: [{name: 'sharedDimOffset', type: 'int'}], - params: {sharedDim, sharedDimReadSize} + params: {'sharedDim': sharedDim, 'sharedDimReadSize': sharedDimReadSize} }; } static prepKernelForDotProduct(shape: number[], group: number, channels: number, kernel: Float32Array): Float32Array { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/fuse_utils.ts b/js/web/lib/onnxjs/backends/webgl/ops/fuse_utils.ts new file mode 100644 index 0000000000000..40eb627315ab3 --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgl/ops/fuse_utils.ts @@ -0,0 +1,23 @@ +import {glslRelu, glslSigmoid} from './unary-op'; + +export function getActicationSnippet(activation: string) { + let activationFunction = ''; + let activationName = ''; + switch (activation) { + case 'Relu': + activationName = glslRelu().name; + activationFunction = glslRelu().body; + break; + case 'Sigmoid': + activationName = glslSigmoid().name; + activationFunction = glslSigmoid().body; + break; + default: + activationName = ''; + activationFunction = ''; + } + const applyActivation = activation ? ` + value = ${activationName}(value);` : + ''; + return {activationFunction, applyActivation}; +} diff --git a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts index a508debcef52f..280b2364f433e 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts @@ -1,11 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Licensed under the MIT license. import {MatMul} from '../../../ops/matmul'; import {Tensor} from '../../../tensor'; import {BroadcastUtil} from '../../../util'; import {WebGLInferenceHandler} from '../inference-handler'; import {ProgramInfo, RunData, WebGLOperator} from '../types'; +import {getActicationSnippet} from './fuse_utils'; export class WebGLMatMulPacked extends MatMul implements WebGLOperator { run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { @@ -13,7 +14,7 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator { } createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo { const hasBias = inputs.length > 2; - const processBias = hasBias ? 'value += vec4(getBias(a[0]*2).xx, getBias(a[0]*2).yy);' : ''; + const processBias = hasBias ? `value += vec4(getBias(a[0]*2).xx, getBias(a[0]*2).yy);` : ``; const aShape = inputs[0].dims; const bShape = inputs[1].dims; const outputShape = BroadcastUtil.calcShape(aShape, bShape, true); @@ -25,8 +26,11 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator { const aRank = aShape.length; const bRank = bShape.length; const sharedDim = aShape[aShape.length - 1]; + + const {activationFunction, applyActivation} = getActicationSnippet(this.activation); // TODO:fix broadcasting const shaderSource = ` + ${activationFunction} vec4 process(int indices[${rank}]) { int a[${aRank}]; int b[${bRank}]; @@ -41,6 +45,7 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator { value += ${getA(aRank)}.ggaa * ${getB(bRank)}.baba; } ${processBias} + ${applyActivation} return value; }`; return { diff --git a/js/web/lib/onnxjs/graph.ts b/js/web/lib/onnxjs/graph.ts index 23dfd803c92ba..6476e4e6eee98 100644 --- a/js/web/lib/onnxjs/graph.ts +++ b/js/web/lib/onnxjs/graph.ts @@ -1,17 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Licensed under the MIT license. import {onnx} from 'onnx-proto'; import {Attribute} from './attribute'; -import {onnxruntime} from './ort-schema/ort-generated'; +import {onnxruntime} from './ortSchema/ort_generated'; import ortFbs = onnxruntime.experimental.fbs; import {Tensor} from './tensor'; import {LongUtil, ProtoUtil} from './util'; export declare namespace Graph { export interface Shape { - readonly dims: readonly number[]; + readonly dims: ReadonlyArray; } export interface ValueType { readonly tensorType: Tensor.DataType; @@ -25,7 +25,7 @@ export declare namespace Graph { readonly from: number; // indices to the Nodes where the values go to. - readonly to: readonly number[]; + readonly to: ReadonlyArray; // value type specification. empty for non-input values. readonly type?: ValueType; @@ -38,10 +38,10 @@ export declare namespace Graph { readonly opType: string; // indices to the Values where the inputs come from. - readonly inputs: readonly number[]; + readonly inputs: ReadonlyArray; // indices to the Values where the outpus go to. - readonly outputs: readonly number[]; + readonly outputs: ReadonlyArray; // the attributes that used by the operator readonly attributes: Attribute; @@ -53,6 +53,8 @@ export declare namespace Graph { export interface Transformer { removeAllIdentityNodes(): void; removeAllDropoutNodes(): void; + + fuseConvActivationNodes(): void; // TODO: add generic functions to manipulate the graph } @@ -62,17 +64,16 @@ export declare namespace Graph { } } -// eslint-disable-next-line @typescript-eslint/no-redeclare export interface Graph { - getInputIndices(): readonly number[]; - getInputNames(): readonly string[]; - getOutputIndices(): readonly number[]; - getOutputNames(): readonly string[]; - getValues(): readonly Graph.Value[]; - getNodes(): readonly Graph.Node[]; + getInputIndices(): ReadonlyArray; + getInputNames(): ReadonlyArray; + getOutputIndices(): ReadonlyArray; + getOutputNames(): ReadonlyArray; + getValues(): ReadonlyArray; + getNodes(): ReadonlyArray; } -// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-redeclare +// tslint:disable-next-line:variable-name export const Graph = { /** * construct a graph from a graph protobuf type @@ -156,27 +157,27 @@ class GraphImpl implements Graph, Graph.Transformer { this.checkIsAcyclic(); } - getInputIndices(): readonly number[] { + getInputIndices(): ReadonlyArray { return this._allInputIndices; } - getInputNames(): readonly string[] { + getInputNames(): ReadonlyArray { return this._allInputNames; } - getOutputIndices(): readonly number[] { + getOutputIndices(): ReadonlyArray { return this._allOutputIndices; } - getOutputNames(): readonly string[] { + getOutputNames(): ReadonlyArray { return this._allOutputNames; } - getValues(): readonly Graph.Value[] { + getValues(): ReadonlyArray { return this._allData; } - getNodes(): readonly Graph.Node[] { + getNodes(): ReadonlyArray { return this._nodes; } @@ -306,10 +307,10 @@ class GraphImpl implements Graph, Graph.Transformer { // operator and ignore the node from the graph if (nodeProto.opType === 'Constant') { if (!nodeProto.attribute || nodeProto.attribute.length !== 1 || !nodeProto.attribute[0].t) { - throw new Error('missing attributes or missing tensor value in attributes for this Constant operator'); + throw new Error(`missing attributes or missing tensor value in attributes for this Constant operator`); } if (!nodeProto.output || nodeProto.output.length !== 1) { - throw new Error('missing output or incorrect number of outputs for this Constant operator'); + throw new Error(`missing output or incorrect number of outputs for this Constant operator`); } node.outputs.pop(); node.executeNode = false; @@ -371,12 +372,12 @@ class GraphImpl implements Graph, Graph.Transformer { if (valueType !== ortFbs.TypeInfoValue.tensor_type) { throw new Error('Unexpected value type for the nodeArg.'); } - const valueInfo = graph.nodeArgs(j)!.type()!.value(new ortFbs.TensorTypeAndShape())!; - const type = ProtoUtil.tensorDataTypeFromProto(valueInfo.elemType()); - const shape = valueInfo.shape()!; + const valueInfo = graph.nodeArgs(j)?.type()?.value(new ortFbs.TensorTypeAndShape()); + const type = ProtoUtil.tensorDataTypeFromProto(valueInfo?.elemType()!); + const shape = valueInfo?.shape(); const dims = []; - for (let k = 0; k < shape.dimLength()!; k++) { - dims.push(LongUtil.longToNumber(shape.dim(k)!.value()!.dimValue()!)); + for (let k = 0; k < shape?.dimLength()!; k++) { + dims.push(LongUtil.longToNumber(shape?.dim(k)?.value()?.dimValue()!)); } value.type = {shape: {dims}, tensorType: type}; const currentIndex = this._allData.push(value) - 1; @@ -474,10 +475,10 @@ class GraphImpl implements Graph, Graph.Transformer { // operator and ignore the node from the graph if (nodeProto.opType() === 'Constant') { if (nodeProto.attributesLength() !== 1 || !nodeProto.attributes(0)!.t()) { - throw new Error('missing attributes or missing tensor value in attributes for this Constant operator'); + throw new Error(`missing attributes or missing tensor value in attributes for this Constant operator`); } if (nodeProto.outputsLength() !== 1) { - throw new Error('missing output or incorrect number of outputs for this Constant operator'); + throw new Error(`missing output or incorrect number of outputs for this Constant operator`); } node.outputs.pop(); node.executeNode = false; @@ -491,13 +492,13 @@ class GraphImpl implements Graph, Graph.Transformer { // scan node's inputs for (let i = 0; i < this._nodes.length; i++) { const node = this._nodes[i]; - const nodeProto = graph.nodes(i)!; + const nodeProto = graph.nodes(i); - if (nodeProto.inputsLength() === 0) { + if (nodeProto?.inputsLength() === 0) { throw new Error(`missing input for node: ${nodeProto.name}`); } - for (let j = 0; j < nodeProto.inputsLength()!; j++) { - const input = nodeProto.inputs(j)!; + for (let j = 0; j < nodeProto?.inputsLength()!; j++) { + const input = nodeProto?.inputs(j)!; const dataIndex = dataIndices.get(input); if (typeof dataIndex === 'undefined') { throw new Error(`unrecognized input '${input}' for node: ${nodeProto!.name()}`); @@ -536,15 +537,15 @@ class GraphImpl implements Graph, Graph.Transformer { this._nodes[nodeIndex].outputs.forEach((outgoingEdgeIndex) => { const data = this._allData[outgoingEdgeIndex]; if (typeof data.tensor !== 'undefined') { - throw new Error('node outputs should not be initialized'); + throw new Error(`node outputs should not be initialized`); } if (data._from !== nodeIndex) { - throw new Error('from property of the Value object doesn\'t match index of Node being processed'); + throw new Error(`from property of the Value object doesn't match index of Node being processed`); } data._to.forEach((downstreamNodeIndex) => { // back edge found - cyclic if (nodesState[downstreamNodeIndex] === 'gray') { - throw new Error('model graph is cyclic'); + throw new Error(`model graph is cyclic`); } // tree edge found - continue processing by adding it to stack else if (nodesState[downstreamNodeIndex] === 'white') { @@ -560,6 +561,7 @@ class GraphImpl implements Graph, Graph.Transformer { // apply common transform this.removeAllIdentityNodes(); this.removeAllDropoutNodes(); + this.fuseConvActivationNodes(); // apply initializer specific transform if (graphInitializer) { @@ -658,12 +660,12 @@ class GraphImpl implements Graph, Graph.Transformer { private deleteNode(nodeIndex: number) { const node = this._nodes[nodeIndex]; if (node.inputs.length > 1) { - throw new Error('Node deletion with multiple inputs is not supported. '); + throw new Error(`Node deletion with multiple inputs is not supported. `); } if (node.outputs.length > 1) { for (let i = 1; i < node.outputs.length; i++) { if (this._allData[node.outputs[i]].to.length > 0) { - throw new Error('Node deletion with more than one output connected to other nodes is not supported. '); + throw new Error(`Node deletion with more than one output connected to other nodes is not supported. `); } } } @@ -678,7 +680,7 @@ class GraphImpl implements Graph, Graph.Transformer { const delIndex = this._allData[inputValueIndex].to.indexOf(nodeIndex); // should not happen if (delIndex === -1) { - throw new Error('The Value object doesn\'t have the current Node in it\'s \'to\' property '); + throw new Error(`The Value object doesn't have the current Node in it's 'to' property `); } this._allData[inputValueIndex].to.splice(delIndex, 1); @@ -697,7 +699,7 @@ class GraphImpl implements Graph, Graph.Transformer { const replaceIndex = this._nodes[nodeIndex].inputs.indexOf(outputValueIndex); // should not happen if (replaceIndex === -1) { - throw new Error('The Node object doesn\'t have the output Value in it\'s \'inputs\' property '); + throw new Error(`The Node object doesn't have the output Value in it's 'inputs' property `); } this._nodes[nodeIndex].inputs[replaceIndex] = inputValueIndex; this._allData[inputValueIndex].to.push(nodeIndex); @@ -712,14 +714,14 @@ class GraphImpl implements Graph, Graph.Transformer { if (node.opType === 'Dropout') { // the node should have exactly 1 input and 1 or 2 outputs if (node.inputs.length !== 1) { - throw new Error('Dropout nodes should only contain one input. '); + throw new Error(`Dropout nodes should only contain one input. `); } if (node.outputs.length !== 1 && node.outputs.length !== 2) { - throw new Error('Dropout nodes should contain either 1 or 2 output(s)'); + throw new Error(`Dropout nodes should contain either 1 or 2 output(s)`); } // the second output should not be referenced by any other node if (node.outputs.length === 2 && this._allData[node.outputs[1]]._to.length !== 0) { - throw new Error('Dropout nodes\'s second output should not be referenced by other nodes'); + throw new Error(`Dropout nodes's second output should not be referenced by other nodes`); } this.deleteNode(nodeIndex); } @@ -737,4 +739,27 @@ class GraphImpl implements Graph, Graph.Transformer { nodeIndex++; } } + + isActivation(n: Node): boolean { + switch (n.opType) { + // TODO: add other activation methods + case 'Relu': + case 'Sigmoid': + return true; + default: + return false; + } + } + + fuseConvActivationNodes() { + for (const node of this._nodes) { + if (node.opType === 'Conv') { + const next = this._allData[node.outputs[0]]._to; + if (next.length === 1 && this.isActivation(this._nodes[next[0]])) { + node.attributes.set('__internal_activation', 'string', (this._nodes[next[0]].opType)); + this.deleteNode(next[0]); + } + } + } + } } diff --git a/js/web/lib/onnxjs/ops/conv.ts b/js/web/lib/onnxjs/ops/conv.ts index d96c0b690eda1..f0ea235b3fd97 100644 --- a/js/web/lib/onnxjs/ops/conv.ts +++ b/js/web/lib/onnxjs/ops/conv.ts @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Licensed under the MIT license. import {Attribute} from '../attribute'; import {InferenceHandler} from '../backend'; @@ -17,6 +17,7 @@ export abstract class Conv implements Operator { this.kernelShape = attributes.getInts('kernel_shape', []); this.pads = attributes.getInts('pads', [0, 0, 0, 0]); this.strides = attributes.getInts('strides', [1, 1]); + this.activation = attributes.getString('__internal_activation', ''); } checkInputs(inputs: Tensor[]): boolean { @@ -88,4 +89,5 @@ export abstract class Conv implements Operator { protected kernelShape: number[]; protected pads: number[]; protected strides: number[]; + protected activation: string; } diff --git a/js/web/lib/onnxjs/ops/matmul.ts b/js/web/lib/onnxjs/ops/matmul.ts index 567dd6d239078..91ab027a876c9 100644 --- a/js/web/lib/onnxjs/ops/matmul.ts +++ b/js/web/lib/onnxjs/ops/matmul.ts @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. +// Licensed under the MIT license. import {Attribute} from '../attribute'; import {InferenceHandler} from '../backend'; @@ -9,7 +9,9 @@ import {Tensor} from '../tensor'; export abstract class MatMul implements Operator { abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise; - initialize(_attributes: Attribute): void {} + initialize(attributes: Attribute): void { + this.activation = attributes.getString('__internal_activation', ''); + } checkInputs(inputs: Tensor[]): boolean { if (!inputs || inputs.length !== 2) { @@ -38,4 +40,5 @@ export abstract class MatMul implements Operator { return true; } + protected activation: string; } From 40eac2814968e4f64e9c5c309c5446c36b69fb84 Mon Sep 17 00:00:00 2001 From: xueyun zhu Date: Thu, 6 May 2021 16:27:39 -0700 Subject: [PATCH 2/4] fixing merge --- .../onnxjs/backends/webgl/ops/conv-pack.ts | 2 +- js/web/lib/onnxjs/backends/webgl/ops/conv.ts | 4 +- .../onnxjs/backends/webgl/ops/fuse_utils.ts | 1 + .../onnxjs/backends/webgl/ops/matmul-pack.ts | 4 +- js/web/lib/onnxjs/graph.ts | 90 +++++++++---------- js/web/lib/onnxjs/ops/conv.ts | 2 +- js/web/lib/onnxjs/ops/matmul.ts | 2 +- 7 files changed, 52 insertions(+), 53 deletions(-) diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts index 1976494bddfe5..d12df31323a99 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. +// Licensed under the MIT License. import {Attribute} from '../../../attribute'; import {Logger} from '../../../instrument'; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv.ts index 5600b7c98e856..f2a5359291395 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv.ts @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. +// Licensed under the MIT License. import {Attribute} from '../../../attribute'; import {Logger} from '../../../instrument'; @@ -366,7 +366,7 @@ export class WebGLUnpackedConv extends Conv { shaderSource, samplers, variables: [{name: 'sharedDimOffset', type: 'int'}], - params: {'sharedDim': sharedDim, 'sharedDimReadSize': sharedDimReadSize} + params: {sharedDim, sharedDimReadSize} }; } static prepKernelForDotProduct(shape: number[], group: number, channels: number, kernel: Float32Array): Float32Array { diff --git a/js/web/lib/onnxjs/backends/webgl/ops/fuse_utils.ts b/js/web/lib/onnxjs/backends/webgl/ops/fuse_utils.ts index 40eb627315ab3..d2708a574e47b 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/fuse_utils.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/fuse_utils.ts @@ -13,6 +13,7 @@ export function getActicationSnippet(activation: string) { activationFunction = glslSigmoid().body; break; default: + // TODO: adding other activations that can be fused. activationName = ''; activationFunction = ''; } diff --git a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts index 280b2364f433e..cef9ff6bc5d7c 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. +// Licensed under the MIT License. import {MatMul} from '../../../ops/matmul'; import {Tensor} from '../../../tensor'; @@ -14,7 +14,7 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator { } createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo { const hasBias = inputs.length > 2; - const processBias = hasBias ? `value += vec4(getBias(a[0]*2).xx, getBias(a[0]*2).yy);` : ``; + const processBias = hasBias ? 'value += vec4(getBias(a[0]*2).xx, getBias(a[0]*2).yy);' : ''; const aShape = inputs[0].dims; const bShape = inputs[1].dims; const outputShape = BroadcastUtil.calcShape(aShape, bShape, true); diff --git a/js/web/lib/onnxjs/graph.ts b/js/web/lib/onnxjs/graph.ts index 6476e4e6eee98..af281d104703a 100644 --- a/js/web/lib/onnxjs/graph.ts +++ b/js/web/lib/onnxjs/graph.ts @@ -1,17 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. +// Licensed under the MIT License. import {onnx} from 'onnx-proto'; import {Attribute} from './attribute'; -import {onnxruntime} from './ortSchema/ort_generated'; +import {onnxruntime} from './ort-schema/ort-generated'; import ortFbs = onnxruntime.experimental.fbs; import {Tensor} from './tensor'; import {LongUtil, ProtoUtil} from './util'; export declare namespace Graph { export interface Shape { - readonly dims: ReadonlyArray; + readonly dims: readonly number[]; } export interface ValueType { readonly tensorType: Tensor.DataType; @@ -25,7 +25,7 @@ export declare namespace Graph { readonly from: number; // indices to the Nodes where the values go to. - readonly to: ReadonlyArray; + readonly to: readonly number[]; // value type specification. empty for non-input values. readonly type?: ValueType; @@ -38,10 +38,10 @@ export declare namespace Graph { readonly opType: string; // indices to the Values where the inputs come from. - readonly inputs: ReadonlyArray; + readonly inputs: readonly number[]; // indices to the Values where the outpus go to. - readonly outputs: ReadonlyArray; + readonly outputs: readonly number[]; // the attributes that used by the operator readonly attributes: Attribute; @@ -53,8 +53,6 @@ export declare namespace Graph { export interface Transformer { removeAllIdentityNodes(): void; removeAllDropoutNodes(): void; - - fuseConvActivationNodes(): void; // TODO: add generic functions to manipulate the graph } @@ -64,16 +62,17 @@ export declare namespace Graph { } } +// eslint-disable-next-line @typescript-eslint/no-redeclare export interface Graph { - getInputIndices(): ReadonlyArray; - getInputNames(): ReadonlyArray; - getOutputIndices(): ReadonlyArray; - getOutputNames(): ReadonlyArray; - getValues(): ReadonlyArray; - getNodes(): ReadonlyArray; + getInputIndices(): readonly number[]; + getInputNames(): readonly string[]; + getOutputIndices(): readonly number[]; + getOutputNames(): readonly string[]; + getValues(): readonly Graph.Value[]; + getNodes(): readonly Graph.Node[]; } -// tslint:disable-next-line:variable-name +// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-redeclare export const Graph = { /** * construct a graph from a graph protobuf type @@ -157,27 +156,27 @@ class GraphImpl implements Graph, Graph.Transformer { this.checkIsAcyclic(); } - getInputIndices(): ReadonlyArray { + getInputIndices(): readonly number[] { return this._allInputIndices; } - getInputNames(): ReadonlyArray { + getInputNames(): readonly string[] { return this._allInputNames; } - getOutputIndices(): ReadonlyArray { + getOutputIndices(): readonly number[] { return this._allOutputIndices; } - getOutputNames(): ReadonlyArray { + getOutputNames(): readonly string[] { return this._allOutputNames; } - getValues(): ReadonlyArray { + getValues(): readonly Graph.Value[] { return this._allData; } - getNodes(): ReadonlyArray { + getNodes(): readonly Graph.Node[] { return this._nodes; } @@ -307,10 +306,10 @@ class GraphImpl implements Graph, Graph.Transformer { // operator and ignore the node from the graph if (nodeProto.opType === 'Constant') { if (!nodeProto.attribute || nodeProto.attribute.length !== 1 || !nodeProto.attribute[0].t) { - throw new Error(`missing attributes or missing tensor value in attributes for this Constant operator`); + throw new Error('missing attributes or missing tensor value in attributes for this Constant operator'); } if (!nodeProto.output || nodeProto.output.length !== 1) { - throw new Error(`missing output or incorrect number of outputs for this Constant operator`); + throw new Error('missing output or incorrect number of outputs for this Constant operator'); } node.outputs.pop(); node.executeNode = false; @@ -372,12 +371,12 @@ class GraphImpl implements Graph, Graph.Transformer { if (valueType !== ortFbs.TypeInfoValue.tensor_type) { throw new Error('Unexpected value type for the nodeArg.'); } - const valueInfo = graph.nodeArgs(j)?.type()?.value(new ortFbs.TensorTypeAndShape()); - const type = ProtoUtil.tensorDataTypeFromProto(valueInfo?.elemType()!); - const shape = valueInfo?.shape(); + const valueInfo = graph.nodeArgs(j)!.type()!.value(new ortFbs.TensorTypeAndShape())!; + const type = ProtoUtil.tensorDataTypeFromProto(valueInfo.elemType()); + const shape = valueInfo.shape()!; const dims = []; - for (let k = 0; k < shape?.dimLength()!; k++) { - dims.push(LongUtil.longToNumber(shape?.dim(k)?.value()?.dimValue()!)); + for (let k = 0; k < shape.dimLength()!; k++) { + dims.push(LongUtil.longToNumber(shape.dim(k)!.value()!.dimValue()!)); } value.type = {shape: {dims}, tensorType: type}; const currentIndex = this._allData.push(value) - 1; @@ -475,10 +474,10 @@ class GraphImpl implements Graph, Graph.Transformer { // operator and ignore the node from the graph if (nodeProto.opType() === 'Constant') { if (nodeProto.attributesLength() !== 1 || !nodeProto.attributes(0)!.t()) { - throw new Error(`missing attributes or missing tensor value in attributes for this Constant operator`); + throw new Error('missing attributes or missing tensor value in attributes for this Constant operator'); } if (nodeProto.outputsLength() !== 1) { - throw new Error(`missing output or incorrect number of outputs for this Constant operator`); + throw new Error('missing output or incorrect number of outputs for this Constant operator'); } node.outputs.pop(); node.executeNode = false; @@ -492,13 +491,13 @@ class GraphImpl implements Graph, Graph.Transformer { // scan node's inputs for (let i = 0; i < this._nodes.length; i++) { const node = this._nodes[i]; - const nodeProto = graph.nodes(i); + const nodeProto = graph.nodes(i)!; - if (nodeProto?.inputsLength() === 0) { + if (nodeProto.inputsLength() === 0) { throw new Error(`missing input for node: ${nodeProto.name}`); } - for (let j = 0; j < nodeProto?.inputsLength()!; j++) { - const input = nodeProto?.inputs(j)!; + for (let j = 0; j < nodeProto.inputsLength()!; j++) { + const input = nodeProto.inputs(j)!; const dataIndex = dataIndices.get(input); if (typeof dataIndex === 'undefined') { throw new Error(`unrecognized input '${input}' for node: ${nodeProto!.name()}`); @@ -537,15 +536,15 @@ class GraphImpl implements Graph, Graph.Transformer { this._nodes[nodeIndex].outputs.forEach((outgoingEdgeIndex) => { const data = this._allData[outgoingEdgeIndex]; if (typeof data.tensor !== 'undefined') { - throw new Error(`node outputs should not be initialized`); + throw new Error('node outputs should not be initialized'); } if (data._from !== nodeIndex) { - throw new Error(`from property of the Value object doesn't match index of Node being processed`); + throw new Error('from property of the Value object doesn\'t match index of Node being processed'); } data._to.forEach((downstreamNodeIndex) => { // back edge found - cyclic if (nodesState[downstreamNodeIndex] === 'gray') { - throw new Error(`model graph is cyclic`); + throw new Error('model graph is cyclic'); } // tree edge found - continue processing by adding it to stack else if (nodesState[downstreamNodeIndex] === 'white') { @@ -561,7 +560,6 @@ class GraphImpl implements Graph, Graph.Transformer { // apply common transform this.removeAllIdentityNodes(); this.removeAllDropoutNodes(); - this.fuseConvActivationNodes(); // apply initializer specific transform if (graphInitializer) { @@ -660,12 +658,12 @@ class GraphImpl implements Graph, Graph.Transformer { private deleteNode(nodeIndex: number) { const node = this._nodes[nodeIndex]; if (node.inputs.length > 1) { - throw new Error(`Node deletion with multiple inputs is not supported. `); + throw new Error('Node deletion with multiple inputs is not supported. '); } if (node.outputs.length > 1) { for (let i = 1; i < node.outputs.length; i++) { if (this._allData[node.outputs[i]].to.length > 0) { - throw new Error(`Node deletion with more than one output connected to other nodes is not supported. `); + throw new Error('Node deletion with more than one output connected to other nodes is not supported. '); } } } @@ -680,7 +678,7 @@ class GraphImpl implements Graph, Graph.Transformer { const delIndex = this._allData[inputValueIndex].to.indexOf(nodeIndex); // should not happen if (delIndex === -1) { - throw new Error(`The Value object doesn't have the current Node in it's 'to' property `); + throw new Error('The Value object doesn\'t have the current Node in it\'s \'to\' property '); } this._allData[inputValueIndex].to.splice(delIndex, 1); @@ -699,7 +697,7 @@ class GraphImpl implements Graph, Graph.Transformer { const replaceIndex = this._nodes[nodeIndex].inputs.indexOf(outputValueIndex); // should not happen if (replaceIndex === -1) { - throw new Error(`The Node object doesn't have the output Value in it's 'inputs' property `); + throw new Error('The Node object doesn\'t have the output Value in it\'s \'inputs\' property '); } this._nodes[nodeIndex].inputs[replaceIndex] = inputValueIndex; this._allData[inputValueIndex].to.push(nodeIndex); @@ -714,14 +712,14 @@ class GraphImpl implements Graph, Graph.Transformer { if (node.opType === 'Dropout') { // the node should have exactly 1 input and 1 or 2 outputs if (node.inputs.length !== 1) { - throw new Error(`Dropout nodes should only contain one input. `); + throw new Error('Dropout nodes should only contain one input. '); } if (node.outputs.length !== 1 && node.outputs.length !== 2) { - throw new Error(`Dropout nodes should contain either 1 or 2 output(s)`); + throw new Error('Dropout nodes should contain either 1 or 2 output(s)'); } // the second output should not be referenced by any other node if (node.outputs.length === 2 && this._allData[node.outputs[1]]._to.length !== 0) { - throw new Error(`Dropout nodes's second output should not be referenced by other nodes`); + throw new Error('Dropout nodes\'s second output should not be referenced by other nodes'); } this.deleteNode(nodeIndex); } @@ -762,4 +760,4 @@ class GraphImpl implements Graph, Graph.Transformer { } } } -} +} \ No newline at end of file diff --git a/js/web/lib/onnxjs/ops/conv.ts b/js/web/lib/onnxjs/ops/conv.ts index f0ea235b3fd97..b146ad69ede5d 100644 --- a/js/web/lib/onnxjs/ops/conv.ts +++ b/js/web/lib/onnxjs/ops/conv.ts @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. +// Licensed under the MIT License. import {Attribute} from '../attribute'; import {InferenceHandler} from '../backend'; diff --git a/js/web/lib/onnxjs/ops/matmul.ts b/js/web/lib/onnxjs/ops/matmul.ts index 91ab027a876c9..0ad0aeb094886 100644 --- a/js/web/lib/onnxjs/ops/matmul.ts +++ b/js/web/lib/onnxjs/ops/matmul.ts @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. +// Licensed under the MIT License. import {Attribute} from '../attribute'; import {InferenceHandler} from '../backend'; From 6425315d8b3a5128cdc36d924f46dcf7cfedd2b0 Mon Sep 17 00:00:00 2001 From: xueyun zhu Date: Thu, 6 May 2021 17:17:17 -0700 Subject: [PATCH 3/4] fix file name in kebab case --- js/web/lib/onnxjs/backends/webgl/ops/conv.ts | 2 +- .../onnxjs/backends/webgl/ops/{fuse_utils.ts => fuse-utils.ts} | 0 js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename js/web/lib/onnxjs/backends/webgl/ops/{fuse_utils.ts => fuse-utils.ts} (100%) diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv.ts index f2a5359291395..83b7b7bae375e 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv.ts @@ -12,7 +12,7 @@ import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../t import {WebGLContext} from '../webgl-context'; import {WebGLConvPacked} from './conv-pack'; -import {getActicationSnippet} from './fuse_utils'; +import {getActicationSnippet} from './fuse-utils'; export class WebGLConv extends Conv { unpackedGroupedConvImpl: WebGLUnpackedGroupedConv; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/fuse_utils.ts b/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts similarity index 100% rename from js/web/lib/onnxjs/backends/webgl/ops/fuse_utils.ts rename to js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts diff --git a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts index cef9ff6bc5d7c..a77d06537845d 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts @@ -6,7 +6,7 @@ import {Tensor} from '../../../tensor'; import {BroadcastUtil} from '../../../util'; import {WebGLInferenceHandler} from '../inference-handler'; import {ProgramInfo, RunData, WebGLOperator} from '../types'; -import {getActicationSnippet} from './fuse_utils'; +import {getActicationSnippet} from './fuse-utils'; export class WebGLMatMulPacked extends MatMul implements WebGLOperator { run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { From 6454dc8f87999be72a2dfa0d373b404f818ffb37 Mon Sep 17 00:00:00 2001 From: xueyun zhu Date: Fri, 7 May 2021 15:12:50 -0700 Subject: [PATCH 4/4] fix lint error --- js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts | 2 +- js/web/lib/onnxjs/backends/webgl/ops/conv.ts | 2 +- js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts index d12df31323a99..8bc69a9ee54eb 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts @@ -37,7 +37,7 @@ export class WebGLConvPacked extends Conv { const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides); const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides); const matmul = new WebGLMatMulPacked(); - if (!!this.activation) { + if (this.activation) { const attributes = new Attribute(undefined); attributes.set('__internal_activation', 'string', (this.activation)); matmul.initialize(attributes); diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv.ts index 83b7b7bae375e..9012b07a68f65 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv.ts @@ -68,7 +68,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo { const hasBias = inputs.length > 2; - const processBias = hasBias ? `value += getBias(output_channel);` : ``; + const processBias = hasBias ? 'value += getBias(output_channel);' : ''; const xShape = inputs[0].dims.slice(); const wShape = inputs[1].dims.slice(); const outputChannelsPerGroup = wShape[0] / this.group; diff --git a/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts b/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts index d2708a574e47b..115cae72bc31e 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + import {glslRelu, glslSigmoid} from './unary-op'; export function getActicationSnippet(activation: string) {