diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts index 2f9a353935938..8bc69a9ee54eb 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {Attribute} from '../../../attribute'; import {Logger} from '../../../instrument'; import {Conv} from '../../../ops/conv'; import {Tensor} from '../../../tensor'; @@ -36,6 +37,11 @@ export class WebGLConvPacked extends Conv { const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides); const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides); const matmul = new WebGLMatMulPacked(); + if (this.activation) { + const attributes = new Attribute(undefined); + attributes.set('__internal_activation', 'string', (this.activation)); + matmul.initialize(attributes); + } const reshape = new WebGLReshapePacked(); // shape for kernel reshape const shape = diff --git a/js/web/lib/onnxjs/backends/webgl/ops/conv.ts b/js/web/lib/onnxjs/backends/webgl/ops/conv.ts index 0fa0d61a34414..9012b07a68f65 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/conv.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/conv.ts @@ -10,7 +10,9 @@ import {getGlsl} from '../glsl-source'; import {WebGLInferenceHandler} from '../inference-handler'; import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../types'; import {WebGLContext} from '../webgl-context'; + import {WebGLConvPacked} from './conv-pack'; +import {getActicationSnippet} from './fuse-utils'; export class WebGLConv extends Conv { unpackedGroupedConvImpl: WebGLUnpackedGroupedConv; @@ -66,7 +68,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo { const hasBias = inputs.length > 2; - const processBias = hasBias ? 'dotProd += getBias(output_channel);' : ''; + const processBias = hasBias ? 'value += getBias(output_channel);' : ''; const xShape = inputs[0].dims.slice(); const wShape = inputs[1].dims.slice(); const outputChannelsPerGroup = wShape[0] / this.group; @@ -85,10 +87,12 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { const outputShape = WebGLConv.calcOutputShape(xShape, wShape, this.dilations, this.pads, this.strides); const glsl = getGlsl(handler.session.backend.glContext.version); + const {activationFunction, applyActivation} = getActicationSnippet(this.activation); + const shaderSource = ` const ivec2 strides = ivec2(${this.strides[0]}, ${this.strides[1]}); const ivec2 pads = ivec2(${this.pads[0]}, ${this.pads[1]}); - + ${activationFunction} void main() { ivec4 coords = getOutputCoords(); int batch = coords.x; @@ -96,7 +100,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { ivec2 xRCCorner = coords.zw * strides - pads; int group_id = output_channel / ${outputChannelsPerGroup}; - float dotProd = 0.0; + float value = 0.0; for (int wInChannel = 0; wInChannel < ${wShape[1]}; wInChannel++) { int input_channel = group_id * ${wShape[1]} + wInChannel; for (int wHeight = 0; wHeight < ${wShape[2]}; wHeight++) { @@ -114,12 +118,13 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { float xVal = getX(batch, input_channel, xWidth, xHeight); float wVal = getW(output_channel, wInChannel, wWidth, wHeight); - dotProd += xVal*wVal; + value += xVal*wVal; } } } ${processBias} - ${glsl.output} = vec4(dotProd, .0, .0, .0); + ${applyActivation} + ${glsl.output} = vec4(value, .0, .0, .0); } `; return { @@ -215,7 +220,6 @@ export class WebGLUnpackedConv extends Conv { let blend = false; for (let k = 0; k < sharedDim; k += sharedDimReadSize) { Logger.verbose('MatMul2D', `k = ${k}, sharedDim: ${sharedDim}, readSize = ${sharedDimReadSize}`); - if (k === sharedDimReadSize) { blend = true; gl.enable(gl.BLEND); @@ -248,6 +252,7 @@ export class WebGLUnpackedConv extends Conv { const im2colDims = WebGLUnpackedConv.calcIm2ColDims(xshape, kshape, outputShape, 4); const outputLayout = inferenceHandler.createTextureLayoutFromShape( im2colDims, 4, [im2colDims[0], im2colDims[1], im2colDims[2], im2colDims[3] * 4], {breakAxis: 3}); + const shaderSource = ` const int XC = ${xshape[1]}; const int XH = ${xshape[2]}; @@ -263,13 +268,12 @@ export class WebGLUnpackedConv extends Conv { const int KHKW = KH*KW; const int XCKHKW = XC * KHKW; const int outputChannels = 4; - vec4 process(int indices[${rank}]) { int b = indices[0]; // batch size int oh = indices[1] * strideH - padH; //output height int ow = indices[2] * strideW - padW; //output width int p = indices[3] * outputChannels; //patch - vec4 v = vec4(0.0); + vec4 value = vec4(0.0); for(int i=0; i < outputChannels; ++i) { if(p < XCKHKW) { int patchC = p / KHKW; @@ -286,12 +290,12 @@ export class WebGLUnpackedConv extends Conv { xh2 < XH && xw2 >= 0 && xw2 < XW) { - v[i] = _X(x); + value[i] = _X(x); } } ++p; } - return v; + return value; } `; return { @@ -321,7 +325,7 @@ export class WebGLUnpackedConv extends Conv { const outputLayout = inferenceHandler.createTextureLayoutFromShape(outputShape); const initValue = (inputs.length < 3) ? '0.0' : '_B(b)'; const sharedDim = im2colLayout.shape[3]; - const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported; + const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported && !this.activation; const sharedDimReadSize = blendEnabled && inferenceHandler.session.backend.matmulMaxBatchSize ? this.calcSharedDimReadSize(inferenceHandler.session.backend.matmulMaxBatchSize, sharedDim) : sharedDim; @@ -329,8 +333,12 @@ export class WebGLUnpackedConv extends Conv { if (inputs.length === 3) { samplers.push('B'); } + + const {activationFunction, applyActivation} = getActicationSnippet(this.activation); + const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); const shaderSource = ` + ${activationFunction} float process(int indices[${rank}]) { int b[1]; b[0] = indices[1]; @@ -341,15 +349,16 @@ export class WebGLUnpackedConv extends Conv { int im2colOffset = im2col[0] * ${im2colLayout.strides[0]} + im2col[1] * ${ im2colLayout.strides[1]} + im2col[2] * ${im2colLayout.strides[2]} + sharedDimOffset; int kernelOffset = indices[1] * ${kLayout.strides[0]} + sharedDimOffset; - float sum = sharedDimOffset == 0 ? ${initValue} : 0.0; + float value = sharedDimOffset == 0 ? ${initValue} : 0.0; for (int i = 0; i < ${sharedDimReadSize}; ++i) { vec2 im2colCoords = offsetToCoords(im2colOffset, ${im2colLayout.width}, ${im2colLayout.height}); vec2 kernelCoords = offsetToCoords(kernelOffset, ${kLayout.width}, ${kLayout.height}); - sum += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords)); + value += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords)); ++im2colOffset; ++kernelOffset; } - return sum; + ${applyActivation} + return value; }`; return { inputLayouts: inputs.length === 3 ? [im2colLayout, kLayout, bLayout!] : [im2colLayout, kLayout], diff --git a/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts b/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts new file mode 100644 index 0000000000000..115cae72bc31e --- /dev/null +++ b/js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {glslRelu, glslSigmoid} from './unary-op'; + +export function getActicationSnippet(activation: string) { + let activationFunction = ''; + let activationName = ''; + switch (activation) { + case 'Relu': + activationName = glslRelu().name; + activationFunction = glslRelu().body; + break; + case 'Sigmoid': + activationName = glslSigmoid().name; + activationFunction = glslSigmoid().body; + break; + default: + // TODO: adding other activations that can be fused. + activationName = ''; + activationFunction = ''; + } + const applyActivation = activation ? ` + value = ${activationName}(value);` : + ''; + return {activationFunction, applyActivation}; +} diff --git a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts index a508debcef52f..a77d06537845d 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts @@ -6,6 +6,7 @@ import {Tensor} from '../../../tensor'; import {BroadcastUtil} from '../../../util'; import {WebGLInferenceHandler} from '../inference-handler'; import {ProgramInfo, RunData, WebGLOperator} from '../types'; +import {getActicationSnippet} from './fuse-utils'; export class WebGLMatMulPacked extends MatMul implements WebGLOperator { run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { @@ -25,8 +26,11 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator { const aRank = aShape.length; const bRank = bShape.length; const sharedDim = aShape[aShape.length - 1]; + + const {activationFunction, applyActivation} = getActicationSnippet(this.activation); // TODO:fix broadcasting const shaderSource = ` + ${activationFunction} vec4 process(int indices[${rank}]) { int a[${aRank}]; int b[${bRank}]; @@ -41,6 +45,7 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator { value += ${getA(aRank)}.ggaa * ${getB(bRank)}.baba; } ${processBias} + ${applyActivation} return value; }`; return { diff --git a/js/web/lib/onnxjs/graph.ts b/js/web/lib/onnxjs/graph.ts index 23dfd803c92ba..af281d104703a 100644 --- a/js/web/lib/onnxjs/graph.ts +++ b/js/web/lib/onnxjs/graph.ts @@ -737,4 +737,27 @@ class GraphImpl implements Graph, Graph.Transformer { nodeIndex++; } } -} + + isActivation(n: Node): boolean { + switch (n.opType) { + // TODO: add other activation methods + case 'Relu': + case 'Sigmoid': + return true; + default: + return false; + } + } + + fuseConvActivationNodes() { + for (const node of this._nodes) { + if (node.opType === 'Conv') { + const next = this._allData[node.outputs[0]]._to; + if (next.length === 1 && this.isActivation(this._nodes[next[0]])) { + node.attributes.set('__internal_activation', 'string', (this._nodes[next[0]].opType)); + this.deleteNode(next[0]); + } + } + } + } +} \ No newline at end of file diff --git a/js/web/lib/onnxjs/ops/conv.ts b/js/web/lib/onnxjs/ops/conv.ts index d96c0b690eda1..b146ad69ede5d 100644 --- a/js/web/lib/onnxjs/ops/conv.ts +++ b/js/web/lib/onnxjs/ops/conv.ts @@ -17,6 +17,7 @@ export abstract class Conv implements Operator { this.kernelShape = attributes.getInts('kernel_shape', []); this.pads = attributes.getInts('pads', [0, 0, 0, 0]); this.strides = attributes.getInts('strides', [1, 1]); + this.activation = attributes.getString('__internal_activation', ''); } checkInputs(inputs: Tensor[]): boolean { @@ -88,4 +89,5 @@ export abstract class Conv implements Operator { protected kernelShape: number[]; protected pads: number[]; protected strides: number[]; + protected activation: string; } diff --git a/js/web/lib/onnxjs/ops/matmul.ts b/js/web/lib/onnxjs/ops/matmul.ts index 567dd6d239078..0ad0aeb094886 100644 --- a/js/web/lib/onnxjs/ops/matmul.ts +++ b/js/web/lib/onnxjs/ops/matmul.ts @@ -9,7 +9,9 @@ import {Tensor} from '../tensor'; export abstract class MatMul implements Operator { abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise; - initialize(_attributes: Attribute): void {} + initialize(attributes: Attribute): void { + this.activation = attributes.getString('__internal_activation', ''); + } checkInputs(inputs: Tensor[]): boolean { if (!inputs || inputs.length !== 2) { @@ -38,4 +40,5 @@ export abstract class MatMul implements Operator { return true; } + protected activation: string; }