
add packed conv fuse
xzhu1900 committed May 1, 2021
1 parent ff67621 commit 2927e7d
Showing 13 changed files with 46 additions and 30 deletions.
7 changes: 4 additions & 3 deletions lib/backends/webgl/ops/concat_packed.ts
@@ -110,8 +110,9 @@ export class WebGLPackedConcat extends Concat implements WebGLOperator {
       `;
 
     return {
-      inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t)),
-      outputLayout: handler.createTextureLayoutFromShape(outputShape),
+      inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t, 4, true, t.dims, true)),
+      outputLayout:
+          handler.createTextureLayoutFromShape(outputShape, 4, outputShape, {isPacked: true, reverseWH: true}),
       samplers,
       shaderSource,
       hasMain: true,
@@ -120,7 +121,7 @@ export class WebGLPackedConcat extends Concat implements WebGLOperator {
     };
   }
   createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
-    const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i]));
+    const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i], true));
     return {
       inputTextureDatas: inputTDs,
       outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
6 changes: 6 additions & 0 deletions lib/backends/webgl/ops/conv-pack.ts
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT license.
 
+import {Attribute} from '../../../attribute';
 import {Logger} from '../../../instrument';
 import {Conv} from '../../../ops/conv';
 import {Tensor} from '../../../tensor';
@@ -36,6 +37,11 @@ export class WebGLConvPacked extends Conv {
     const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides);
     const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides);
     const matmul = new WebGLMatMulPacked();
+    if (!!this.activation) {
+      const attributes = new Attribute(undefined);
+      attributes.set('__internal_activation', 'string', (this.activation));
+      matmul.initialize(attributes);
+    }
     const reshape = new WebGLReshapePacked();
     // shape for kernel reshape
     const shape =
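Aside (not part of the commit): the reason for forwarding this.activation into the packed matmul is that the non-linearity of a fused Conv+Relu/Sigmoid can then be applied inside the matmul's own shader pass instead of as a separate operator over the conv output. A CPU-side sketch of that idea on plain numbers, purely illustrative and not the WebGL implementation:

// Sketch only: plain-number stand-ins for the fused-activation idea.
type Matrix = number[][];

const relu = (v: number): number => Math.max(0, v);

function matmul(a: Matrix, b: Matrix): Matrix {
  return a.map(row => b[0].map((_, j) => row.reduce((acc, v, k) => acc + v * b[k][j], 0)));
}

// Unfused: one full matmul pass, then a second pass over the output for the activation.
function matmulThenRelu(a: Matrix, b: Matrix): Matrix {
  return matmul(a, b).map(row => row.map(relu));
}

// Fused: the activation is applied while each output value is still in hand; this is
// what passing '__internal_activation' down to WebGLMatMulPacked achieves in GLSL.
function matmulFusedRelu(a: Matrix, b: Matrix): Matrix {
  return a.map(row => b[0].map((_, j) => relu(row.reduce((acc, v, k) => acc + v * b[k][j], 0))));
}

Both functions return the same values; the fused form simply avoids materializing the pre-activation result, which on WebGL saves a render pass and a texture round-trip.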
3 changes: 2 additions & 1 deletion lib/backends/webgl/ops/conv.ts
@@ -3,7 +3,7 @@
 
 import {Attribute} from '../../../attribute';
 import {Logger} from '../../../instrument';
-import {Conv, getActicationSnippet} from '../../../ops/conv';
+import {Conv} from '../../../ops/conv';
 import {Tensor} from '../../../tensor';
 import {PoolConvUtil} from '../../../util';
 import {getGlsl} from '../glsl-source';
@@ -12,6 +12,7 @@ import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../t
 import {WebGLContext} from '../webgl-context';
 
 import {WebGLConvPacked} from './conv-pack';
+import {getActicationSnippet} from './fuse_utils';
 
 export class WebGLConv extends Conv {
   unpackedGroupedConvImpl: WebGLUnpackedGroupedConv;
23 changes: 23 additions & 0 deletions lib/backends/webgl/ops/fuse_utils.ts
@@ -0,0 +1,23 @@
+import {glslRelu, glslSigmoid} from './unary-op';
+
+export function getActicationSnippet(activation: string) {
+  let activationFunction = '';
+  let activationName = '';
+  switch (activation) {
+    case 'Relu':
+      activationName = glslRelu().name;
+      activationFunction = glslRelu().body;
+      break;
+    case 'Sigmoid':
+      activationName = glslSigmoid().name;
+      activationFunction = glslSigmoid().body;
+      break;
+    default:
+      activationName = '';
+      activationFunction = '';
+  }
+  const applyActivation = activation ? `
+      value = ${activationName}(value);` :
+                                       '';
+  return {activationFunction, applyActivation};
+}
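The two strings returned above are meant to be spliced into a generated shader: activationFunction contributes the GLSL function definition, and applyActivation injects the call right before the shader returns its value, which is exactly how the packed matmul below uses them. A minimal sketch of that pattern; buildProcessShader is a hypothetical helper, not part of this commit:

import {getActicationSnippet} from './fuse_utils';

// Hypothetical helper showing how the snippets slot into a shader template.
function buildProcessShader(activation: string): string {
  const {activationFunction, applyActivation} = getActicationSnippet(activation);
  return `
    ${activationFunction}
    vec4 process(int indices[2]) {
      vec4 value = vec4(0.0);
      // ... accumulate the output value here ...
      ${applyActivation}
      return value;
    }`;
}

// buildProcessShader('Relu') defines the relu snippet and appends 'value = <name>(value);'.
// buildProcessShader('') leaves both snippets empty, so the shader is unchanged.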
5 changes: 5 additions & 0 deletions lib/backends/webgl/ops/matmul-pack.ts
@@ -6,6 +6,7 @@ import {Tensor} from '../../../tensor';
 import {BroadcastUtil} from '../../../util';
 import {WebGLInferenceHandler} from '../inference-handler';
 import {ProgramInfo, RunData, WebGLOperator} from '../types';
+import {getActicationSnippet} from './fuse_utils';
 
 export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
   run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
@@ -25,8 +26,11 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
     const aRank = aShape.length;
     const bRank = bShape.length;
     const sharedDim = aShape[aShape.length - 1];
+
+    const {activationFunction, applyActivation} = getActicationSnippet(this.activation);
     // TODO:fix broadcasting
     const shaderSource = `
+      ${activationFunction}
      vec4 process(int indices[${rank}]) {
        int a[${aRank}];
        int b[${bRank}];
@@ -41,6 +45,7 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator {
          value += ${getA(aRank)}.ggaa * ${getB(bRank)}.baba;
        }
        ${processBias}
+       ${applyActivation}
        return value;
      }`;
     return {
23 changes: 0 additions & 23 deletions lib/ops/conv.ts
@@ -3,7 +3,6 @@
 
 import {Attribute} from '../attribute';
 import {InferenceHandler} from '../backend';
-import {glslRelu, glslSigmoid} from '../backends/webgl/ops/unary-op';
 import {Operator} from '../operators';
 import {Tensor} from '../tensor';
 
@@ -92,25 +91,3 @@ export abstract class Conv implements Operator {
   protected strides: number[];
   protected activation: string;
 }
-
-export function getActicationSnippet(activation: string) {
-  let activationFunction = '';
-  let activationName = '';
-  switch (activation) {
-    case 'Relu':
-      activationName = glslRelu().name;
-      activationFunction = glslRelu().body;
-      break;
-    case 'Sigmoid':
-      activationName = glslSigmoid().name;
-      activationFunction = glslSigmoid().body;
-      break;
-    default:
-      activationName = '';
-      activationFunction = '';
-  }
-  const applyActivation = activation ? `
-      value = ${activationName}(value);` :
-      '';
-  return {activationFunction, applyActivation};
-}
5 changes: 4 additions & 1 deletion lib/ops/matmul.ts
@@ -9,7 +9,9 @@ import {Tensor} from '../tensor';
 export abstract class MatMul implements Operator {
   abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
 
-  initialize(attributes: Attribute): void {}
+  initialize(attributes: Attribute): void {
+    this.activation = attributes.getString('__internal_activation', '');
+  }
 
   checkInputs(inputs: Tensor[]): boolean {
     if (!inputs || inputs.length !== 2) {
@@ -38,4 +40,5 @@
 
     return true;
   }
+  protected activation: string;
 }
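Read together with the conv-pack.ts change above, the activation name round-trips through an internal attribute: the packed conv writes it with attributes.set('__internal_activation', 'string', ...) and MatMul.initialize reads it back with getString, defaulting to an empty string. An illustrative snippet, assuming it lives under lib/ops/ so the relative import paths resolve:

import {Attribute} from '../attribute';
import {WebGLMatMulPacked} from '../backends/webgl/ops/matmul-pack';

// Forward a fused activation to the packed matmul, mirroring conv-pack.ts.
const attributes = new Attribute(undefined);
attributes.set('__internal_activation', 'string', 'Relu');

const matmul = new WebGLMatMulPacked();
matmul.initialize(attributes);  // MatMul.initialize stores 'Relu' in this.activation
// When the matmul builds its shader, getActicationSnippet('Relu') supplies the GLSL
// snippets, so the activation runs in the same pass as the matrix multiply.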
Binary file added test/data/teams_model/msra_190729.onnx (binary file not shown)
Binary file added test/data/teams_model/test_data_set_0/input_0.pb (binary file not shown)
3 more binary test files changed (paths and contents not shown)
4 changes: 2 additions & 2 deletions test/unittests/backends/webgl/test_concat_packed.ts
@@ -98,8 +98,8 @@ describe('#UnitTest# - packed concat - Tensor concat', () => {
       texture: webglTextureB!
     };
 
-    webglInferenceHandler.setTextureData(inputTensorA.dataId, textureDataA);
-    webglInferenceHandler.setTextureData(inputTensorB.dataId, textureDataB);
+    webglInferenceHandler.setTextureData(inputTensorA.dataId, textureDataA, true);
+    webglInferenceHandler.setTextureData(inputTensorB.dataId, textureDataB, true);
 
     // compile shader code
     const programInfo =