Skip to content

Commit

Permalink
[js/web] adding conv fuse logic (#7604)
Browse files Browse the repository at this point in the history
* adding conv fuse logic

* fixing merge

* fix file name in kebab case

* fix lint error
  • Loading branch information
xzhu1900 authored May 10, 2021
1 parent 88c95ef commit c5d2809
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 16 deletions.
6 changes: 6 additions & 0 deletions js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Attribute} from '../../../attribute';
import {Logger} from '../../../instrument';
import {Conv} from '../../../ops/conv';
import {Tensor} from '../../../tensor';
Expand Down Expand Up @@ -36,6 +37,11 @@ export class WebGLConvPacked extends Conv {
const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides);
const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides);
const matmul = new WebGLMatMulPacked();
if (this.activation) {
const attributes = new Attribute(undefined);
attributes.set('__internal_activation', 'string', (this.activation));
matmul.initialize(attributes);
}
const reshape = new WebGLReshapePacked();
// shape for kernel reshape
const shape =
Expand Down
37 changes: 23 additions & 14 deletions js/web/lib/onnxjs/backends/webgl/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import {getGlsl} from '../glsl-source';
import {WebGLInferenceHandler} from '../inference-handler';
import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../types';
import {WebGLContext} from '../webgl-context';

import {WebGLConvPacked} from './conv-pack';
import {getActicationSnippet} from './fuse-utils';

export class WebGLConv extends Conv {
unpackedGroupedConvImpl: WebGLUnpackedGroupedConv;
Expand Down Expand Up @@ -66,7 +68,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {

createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
const hasBias = inputs.length > 2;
const processBias = hasBias ? 'dotProd += getBias(output_channel);' : '';
const processBias = hasBias ? 'value += getBias(output_channel);' : '';
const xShape = inputs[0].dims.slice();
const wShape = inputs[1].dims.slice();
const outputChannelsPerGroup = wShape[0] / this.group;
Expand All @@ -85,18 +87,20 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
const outputShape = WebGLConv.calcOutputShape(xShape, wShape, this.dilations, this.pads, this.strides);
const glsl = getGlsl(handler.session.backend.glContext.version);

const {activationFunction, applyActivation} = getActicationSnippet(this.activation);

const shaderSource = `
const ivec2 strides = ivec2(${this.strides[0]}, ${this.strides[1]});
const ivec2 pads = ivec2(${this.pads[0]}, ${this.pads[1]});
${activationFunction}
void main() {
ivec4 coords = getOutputCoords();
int batch = coords.x;
int output_channel = coords.y;
ivec2 xRCCorner = coords.zw * strides - pads;
int group_id = output_channel / ${outputChannelsPerGroup};
float dotProd = 0.0;
float value = 0.0;
for (int wInChannel = 0; wInChannel < ${wShape[1]}; wInChannel++) {
int input_channel = group_id * ${wShape[1]} + wInChannel;
for (int wHeight = 0; wHeight < ${wShape[2]}; wHeight++) {
Expand All @@ -114,12 +118,13 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
float xVal = getX(batch, input_channel, xWidth, xHeight);
float wVal = getW(output_channel, wInChannel, wWidth, wHeight);
dotProd += xVal*wVal;
value += xVal*wVal;
}
}
}
${processBias}
${glsl.output} = vec4(dotProd, .0, .0, .0);
${applyActivation}
${glsl.output} = vec4(value, .0, .0, .0);
}
`;
return {
Expand Down Expand Up @@ -215,7 +220,6 @@ export class WebGLUnpackedConv extends Conv {
let blend = false;
for (let k = 0; k < sharedDim; k += sharedDimReadSize) {
Logger.verbose('MatMul2D', `k = ${k}, sharedDim: ${sharedDim}, readSize = ${sharedDimReadSize}`);

if (k === sharedDimReadSize) {
blend = true;
gl.enable(gl.BLEND);
Expand Down Expand Up @@ -248,6 +252,7 @@ export class WebGLUnpackedConv extends Conv {
const im2colDims = WebGLUnpackedConv.calcIm2ColDims(xshape, kshape, outputShape, 4);
const outputLayout = inferenceHandler.createTextureLayoutFromShape(
im2colDims, 4, [im2colDims[0], im2colDims[1], im2colDims[2], im2colDims[3] * 4], {breakAxis: 3});

const shaderSource = `
const int XC = ${xshape[1]};
const int XH = ${xshape[2]};
Expand All @@ -263,13 +268,12 @@ export class WebGLUnpackedConv extends Conv {
const int KHKW = KH*KW;
const int XCKHKW = XC * KHKW;
const int outputChannels = 4;
vec4 process(int indices[${rank}]) {
int b = indices[0]; // batch size
int oh = indices[1] * strideH - padH; //output height
int ow = indices[2] * strideW - padW; //output width
int p = indices[3] * outputChannels; //patch
vec4 v = vec4(0.0);
vec4 value = vec4(0.0);
for(int i=0; i < outputChannels; ++i) {
if(p < XCKHKW) {
int patchC = p / KHKW;
Expand All @@ -286,12 +290,12 @@ export class WebGLUnpackedConv extends Conv {
xh2 < XH &&
xw2 >= 0 &&
xw2 < XW) {
v[i] = _X(x);
value[i] = _X(x);
}
}
++p;
}
return v;
return value;
}
`;
return {
Expand Down Expand Up @@ -321,16 +325,20 @@ export class WebGLUnpackedConv extends Conv {
const outputLayout = inferenceHandler.createTextureLayoutFromShape(outputShape);
const initValue = (inputs.length < 3) ? '0.0' : '_B(b)';
const sharedDim = im2colLayout.shape[3];
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported;
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported && !this.activation;
const sharedDimReadSize = blendEnabled && inferenceHandler.session.backend.matmulMaxBatchSize ?
this.calcSharedDimReadSize(inferenceHandler.session.backend.matmulMaxBatchSize, sharedDim) :
sharedDim;
const samplers = ['Im2Col', 'K'];
if (inputs.length === 3) {
samplers.push('B');
}

const {activationFunction, applyActivation} = getActicationSnippet(this.activation);

const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
const shaderSource = `
${activationFunction}
float process(int indices[${rank}]) {
int b[1];
b[0] = indices[1];
Expand All @@ -341,15 +349,16 @@ export class WebGLUnpackedConv extends Conv {
int im2colOffset = im2col[0] * ${im2colLayout.strides[0]} + im2col[1] * ${
im2colLayout.strides[1]} + im2col[2] * ${im2colLayout.strides[2]} + sharedDimOffset;
int kernelOffset = indices[1] * ${kLayout.strides[0]} + sharedDimOffset;
float sum = sharedDimOffset == 0 ? ${initValue} : 0.0;
float value = sharedDimOffset == 0 ? ${initValue} : 0.0;
for (int i = 0; i < ${sharedDimReadSize}; ++i) {
vec2 im2colCoords = offsetToCoords(im2colOffset, ${im2colLayout.width}, ${im2colLayout.height});
vec2 kernelCoords = offsetToCoords(kernelOffset, ${kLayout.width}, ${kLayout.height});
sum += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords));
value += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords));
++im2colOffset;
++kernelOffset;
}
return sum;
${applyActivation}
return value;
}`;
return {
inputLayouts: inputs.length === 3 ? [im2colLayout, kLayout, bLayout!] : [im2colLayout, kLayout],
Expand Down
27 changes: 27 additions & 0 deletions js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {glslRelu, glslSigmoid} from './unary-op';

export function getActicationSnippet(activation: string): {activationFunction: string; applyActivation: string} {
  // Translates a fused activation name into (a) the GLSL helper function source to
  // inline into a shader ('activationFunction') and (b) a one-line snippet that
  // applies it to the shader-local variable `value` ('applyActivation').
  // Unknown or empty activations produce empty snippets, i.e. no-op fusion.
  let activationFunction = '';
  let activationName = '';
  switch (activation) {
    case 'Relu': {
      // invoke glslRelu() once and reuse the result (original called it twice)
      const relu = glslRelu();
      activationName = relu.name;
      activationFunction = relu.body;
      break;
    }
    case 'Sigmoid': {
      const sigmoid = glslSigmoid();
      activationName = sigmoid.name;
      activationFunction = sigmoid.body;
      break;
    }
    default:
      // TODO: adding other activations that can be fused.
      activationName = '';
      activationFunction = '';
  }
  // Guard on activationName (not activation): a non-empty but unrecognized
  // activation string would otherwise emit the malformed GLSL `value = (value);`.
  const applyActivation = activationName ? `
      value = ${activationName}(value);` :
                                           '';
  return {activationFunction, applyActivation};
}
5 changes: 5 additions & 0 deletions js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {Tensor} from '../../../tensor';
import {BroadcastUtil} from '../../../util';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, RunData, WebGLOperator} from '../types';
import {getActicationSnippet} from './fuse-utils';

export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
Expand All @@ -25,8 +26,11 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
const aRank = aShape.length;
const bRank = bShape.length;
const sharedDim = aShape[aShape.length - 1];

const {activationFunction, applyActivation} = getActicationSnippet(this.activation);
// TODO:fix broadcasting
const shaderSource = `
${activationFunction}
vec4 process(int indices[${rank}]) {
int a[${aRank}];
int b[${bRank}];
Expand All @@ -41,6 +45,7 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
value += ${getA(aRank)}.ggaa * ${getB(bRank)}.baba;
}
${processBias}
${applyActivation}
return value;
}`;
return {
Expand Down
25 changes: 24 additions & 1 deletion js/web/lib/onnxjs/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -737,4 +737,27 @@ class GraphImpl implements Graph, Graph.Transformer {
nodeIndex++;
}
}
}

isActivation(n: Node): boolean {
  // Returns true when this node is an activation op that can be fused into a
  // preceding Conv node.
  // TODO: add other activation methods
  const fusableActivations = ['Relu', 'Sigmoid'];
  return fusableActivations.indexOf(n.opType) !== -1;
}

fuseConvActivationNodes() {
  // Folds an activation node that directly (and solely) consumes a Conv output
  // into the Conv node itself, recording the activation name in an internal
  // attribute so the backend can apply it inside the conv kernel.
  for (const node of this._nodes) {
    if (node.opType !== 'Conv') {
      continue;
    }
    const consumers = this._allData[node.outputs[0]]._to;
    if (consumers.length !== 1) {
      // fusion is only valid when the conv output feeds exactly one node
      continue;
    }
    const candidateIndex = consumers[0];
    const candidate = this._nodes[candidateIndex];
    if (this.isActivation(candidate)) {
      node.attributes.set('__internal_activation', 'string', (candidate.opType));
      this.deleteNode(candidateIndex);
    }
  }
}
}
2 changes: 2 additions & 0 deletions js/web/lib/onnxjs/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export abstract class Conv implements Operator {
this.kernelShape = attributes.getInts('kernel_shape', []);
this.pads = attributes.getInts('pads', [0, 0, 0, 0]);
this.strides = attributes.getInts('strides', [1, 1]);
this.activation = attributes.getString('__internal_activation', '');
}

checkInputs(inputs: Tensor[]): boolean {
Expand Down Expand Up @@ -88,4 +89,5 @@ export abstract class Conv implements Operator {
protected kernelShape: number[];
protected pads: number[];
protected strides: number[];
protected activation: string;
}
5 changes: 4 additions & 1 deletion js/web/lib/onnxjs/ops/matmul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import {Tensor} from '../tensor';
export abstract class MatMul implements Operator {
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;

initialize(_attributes: Attribute): void {}
initialize(attributes: Attribute): void {
  // Pick up the activation name fused into this node by the graph transformer,
  // if any; empty string means no fused activation.
  const fusedActivation = attributes.getString('__internal_activation', '');
  this.activation = fusedActivation;
}

checkInputs(inputs: Tensor[]): boolean {
if (!inputs || inputs.length !== 2) {
Expand Down Expand Up @@ -38,4 +40,5 @@ export abstract class MatMul implements Operator {

return true;
}
protected activation: string;
}

0 comments on commit c5d2809

Please sign in to comment.