Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/web] adding conv fuse logic #7604

Merged
merged 4 commits into from
May 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions js/web/lib/onnxjs/backends/webgl/ops/conv-pack.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {Attribute} from '../../../attribute';
import {Logger} from '../../../instrument';
import {Conv} from '../../../ops/conv';
import {Tensor} from '../../../tensor';
Expand Down Expand Up @@ -36,6 +37,11 @@ export class WebGLConvPacked extends Conv {
const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides);
const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides);
const matmul = new WebGLMatMulPacked();
if (this.activation) {
const attributes = new Attribute(undefined);
attributes.set('__internal_activation', 'string', (this.activation));
matmul.initialize(attributes);
}
const reshape = new WebGLReshapePacked();
// shape for kernel reshape
const shape =
Expand Down
37 changes: 23 additions & 14 deletions js/web/lib/onnxjs/backends/webgl/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import {getGlsl} from '../glsl-source';
import {WebGLInferenceHandler} from '../inference-handler';
import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../types';
import {WebGLContext} from '../webgl-context';

import {WebGLConvPacked} from './conv-pack';
import {getActicationSnippet} from './fuse-utils';

export class WebGLConv extends Conv {
unpackedGroupedConvImpl: WebGLUnpackedGroupedConv;
Expand Down Expand Up @@ -66,7 +68,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {

createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
const hasBias = inputs.length > 2;
const processBias = hasBias ? 'dotProd += getBias(output_channel);' : '';
const processBias = hasBias ? 'value += getBias(output_channel);' : '';
const xShape = inputs[0].dims.slice();
const wShape = inputs[1].dims.slice();
const outputChannelsPerGroup = wShape[0] / this.group;
Expand All @@ -85,18 +87,20 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
const outputShape = WebGLConv.calcOutputShape(xShape, wShape, this.dilations, this.pads, this.strides);
const glsl = getGlsl(handler.session.backend.glContext.version);

const {activationFunction, applyActivation} = getActicationSnippet(this.activation);

const shaderSource = `
const ivec2 strides = ivec2(${this.strides[0]}, ${this.strides[1]});
const ivec2 pads = ivec2(${this.pads[0]}, ${this.pads[1]});

${activationFunction}
void main() {
ivec4 coords = getOutputCoords();
int batch = coords.x;
int output_channel = coords.y;
ivec2 xRCCorner = coords.zw * strides - pads;
int group_id = output_channel / ${outputChannelsPerGroup};

float dotProd = 0.0;
float value = 0.0;
for (int wInChannel = 0; wInChannel < ${wShape[1]}; wInChannel++) {
int input_channel = group_id * ${wShape[1]} + wInChannel;
for (int wHeight = 0; wHeight < ${wShape[2]}; wHeight++) {
Expand All @@ -114,12 +118,13 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {

float xVal = getX(batch, input_channel, xWidth, xHeight);
float wVal = getW(output_channel, wInChannel, wWidth, wHeight);
dotProd += xVal*wVal;
value += xVal*wVal;
}
}
}
${processBias}
${glsl.output} = vec4(dotProd, .0, .0, .0);
${applyActivation}
${glsl.output} = vec4(value, .0, .0, .0);
}
`;
return {
Expand Down Expand Up @@ -215,7 +220,6 @@ export class WebGLUnpackedConv extends Conv {
let blend = false;
for (let k = 0; k < sharedDim; k += sharedDimReadSize) {
Logger.verbose('MatMul2D', `k = ${k}, sharedDim: ${sharedDim}, readSize = ${sharedDimReadSize}`);

if (k === sharedDimReadSize) {
blend = true;
gl.enable(gl.BLEND);
Expand Down Expand Up @@ -248,6 +252,7 @@ export class WebGLUnpackedConv extends Conv {
const im2colDims = WebGLUnpackedConv.calcIm2ColDims(xshape, kshape, outputShape, 4);
const outputLayout = inferenceHandler.createTextureLayoutFromShape(
im2colDims, 4, [im2colDims[0], im2colDims[1], im2colDims[2], im2colDims[3] * 4], {breakAxis: 3});

const shaderSource = `
const int XC = ${xshape[1]};
const int XH = ${xshape[2]};
Expand All @@ -263,13 +268,12 @@ export class WebGLUnpackedConv extends Conv {
const int KHKW = KH*KW;
const int XCKHKW = XC * KHKW;
const int outputChannels = 4;

vec4 process(int indices[${rank}]) {
int b = indices[0]; // batch size
int oh = indices[1] * strideH - padH; //output height
int ow = indices[2] * strideW - padW; //output width
int p = indices[3] * outputChannels; //patch
vec4 v = vec4(0.0);
vec4 value = vec4(0.0);
for(int i=0; i < outputChannels; ++i) {
if(p < XCKHKW) {
int patchC = p / KHKW;
Expand All @@ -286,12 +290,12 @@ export class WebGLUnpackedConv extends Conv {
xh2 < XH &&
xw2 >= 0 &&
xw2 < XW) {
v[i] = _X(x);
value[i] = _X(x);
}
}
++p;
}
return v;
return value;
}
`;
return {
Expand Down Expand Up @@ -321,16 +325,20 @@ export class WebGLUnpackedConv extends Conv {
const outputLayout = inferenceHandler.createTextureLayoutFromShape(outputShape);
const initValue = (inputs.length < 3) ? '0.0' : '_B(b)';
const sharedDim = im2colLayout.shape[3];
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported;
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported && !this.activation;
const sharedDimReadSize = blendEnabled && inferenceHandler.session.backend.matmulMaxBatchSize ?
this.calcSharedDimReadSize(inferenceHandler.session.backend.matmulMaxBatchSize, sharedDim) :
sharedDim;
const samplers = ['Im2Col', 'K'];
if (inputs.length === 3) {
samplers.push('B');
}

const {activationFunction, applyActivation} = getActicationSnippet(this.activation);

const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
const shaderSource = `
${activationFunction}
float process(int indices[${rank}]) {
int b[1];
b[0] = indices[1];
Expand All @@ -341,15 +349,16 @@ export class WebGLUnpackedConv extends Conv {
int im2colOffset = im2col[0] * ${im2colLayout.strides[0]} + im2col[1] * ${
im2colLayout.strides[1]} + im2col[2] * ${im2colLayout.strides[2]} + sharedDimOffset;
int kernelOffset = indices[1] * ${kLayout.strides[0]} + sharedDimOffset;
float sum = sharedDimOffset == 0 ? ${initValue} : 0.0;
float value = sharedDimOffset == 0 ? ${initValue} : 0.0;
for (int i = 0; i < ${sharedDimReadSize}; ++i) {
vec2 im2colCoords = offsetToCoords(im2colOffset, ${im2colLayout.width}, ${im2colLayout.height});
vec2 kernelCoords = offsetToCoords(kernelOffset, ${kLayout.width}, ${kLayout.height});
sum += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords));
value += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords));
++im2colOffset;
++kernelOffset;
}
return sum;
${applyActivation}
return value;
}`;
return {
inputLayouts: inputs.length === 3 ? [im2colLayout, kLayout, bLayout!] : [im2colLayout, kLayout],
Expand Down
27 changes: 27 additions & 0 deletions js/web/lib/onnxjs/backends/webgl/ops/fuse-utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {glslRelu, glslSigmoid} from './unary-op';

export function getActicationSnippet(activation: string) {
let activationFunction = '';
let activationName = '';
switch (activation) {
case 'Relu':
activationName = glslRelu().name;
activationFunction = glslRelu().body;
break;
case 'Sigmoid':
activationName = glslSigmoid().name;
activationFunction = glslSigmoid().body;
break;
default:
// TODO: adding other activations that can be fused.
activationName = '';
activationFunction = '';
}
const applyActivation = activation ? `
value = ${activationName}(value);` :
'';
return {activationFunction, applyActivation};
}
5 changes: 5 additions & 0 deletions js/web/lib/onnxjs/backends/webgl/ops/matmul-pack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {Tensor} from '../../../tensor';
import {BroadcastUtil} from '../../../util';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, RunData, WebGLOperator} from '../types';
import {getActicationSnippet} from './fuse-utils';

export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
Expand All @@ -25,8 +26,11 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
const aRank = aShape.length;
const bRank = bShape.length;
const sharedDim = aShape[aShape.length - 1];

const {activationFunction, applyActivation} = getActicationSnippet(this.activation);
// TODO:fix broadcasting
const shaderSource = `
${activationFunction}
vec4 process(int indices[${rank}]) {
int a[${aRank}];
int b[${bRank}];
Expand All @@ -41,6 +45,7 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
value += ${getA(aRank)}.ggaa * ${getB(bRank)}.baba;
}
${processBias}
${applyActivation}
return value;
}`;
return {
Expand Down
25 changes: 24 additions & 1 deletion js/web/lib/onnxjs/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -737,4 +737,27 @@ class GraphImpl implements Graph, Graph.Transformer {
nodeIndex++;
}
}
}

isActivation(n: Node): boolean {
switch (n.opType) {
// TODO: add other activation methods
case 'Relu':
case 'Sigmoid':
return true;
default:
return false;
}
}

fuseConvActivationNodes() {
for (const node of this._nodes) {
if (node.opType === 'Conv') {
const next = this._allData[node.outputs[0]]._to;
if (next.length === 1 && this.isActivation(this._nodes[next[0]])) {
node.attributes.set('__internal_activation', 'string', (this._nodes[next[0]].opType));
this.deleteNode(next[0]);
}
}
}
}
}
2 changes: 2 additions & 0 deletions js/web/lib/onnxjs/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export abstract class Conv implements Operator {
this.kernelShape = attributes.getInts('kernel_shape', []);
this.pads = attributes.getInts('pads', [0, 0, 0, 0]);
this.strides = attributes.getInts('strides', [1, 1]);
this.activation = attributes.getString('__internal_activation', '');
}

checkInputs(inputs: Tensor[]): boolean {
Expand Down Expand Up @@ -88,4 +89,5 @@ export abstract class Conv implements Operator {
protected kernelShape: number[];
protected pads: number[];
protected strides: number[];
protected activation: string;
}
5 changes: 4 additions & 1 deletion js/web/lib/onnxjs/ops/matmul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import {Tensor} from '../tensor';
export abstract class MatMul implements Operator {
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;

initialize(_attributes: Attribute): void {}
initialize(attributes: Attribute): void {
this.activation = attributes.getString('__internal_activation', '');
}

checkInputs(inputs: Tensor[]): boolean {
if (!inputs || inputs.length !== 2) {
Expand Down Expand Up @@ -38,4 +40,5 @@ export abstract class MatMul implements Operator {

return true;
}
protected activation: string;
}