Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
fix conv fuse bug
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhu1900 committed May 3, 2021
1 parent 0fd5d8e commit a2b4912
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 26 deletions.
28 changes: 2 additions & 26 deletions lib/backends/webgl/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import {Attribute} from '../../../attribute';
import {Logger} from '../../../instrument';
import {Conv} from '../../../ops/conv';
import {Conv, getActicationSnippet} from '../../../ops/conv';
import {Tensor} from '../../../tensor';
import {PoolConvUtil} from '../../../util';
import {getGlsl} from '../glsl-source';
Expand All @@ -12,7 +12,6 @@ import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../t
import {WebGLContext} from '../webgl-context';

import {WebGLConvPacked} from './conv-pack';
import {glslRelu, glslSigmoid} from './unary-op';

export class WebGLConv extends Conv {
unpackedGroupedConvImpl: WebGLUnpackedGroupedConv;
Expand Down Expand Up @@ -146,28 +145,6 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
}
}

function getActicationSnippet(activation: string) {
let activationFunction = '';
let activationName = '';
switch (activation) {
case 'Relu':
activationName = glslRelu().name;
activationFunction = glslRelu().body;
break;
case 'Sigmoid':
activationName = glslSigmoid().name;
activationFunction = glslSigmoid().body;
break;
default:
activationName = '';
activationFunction = '';
}
const applyActivation = activation ? `
value = ${activationName}(value);` :
'';
return {activationFunction, applyActivation};
}

export class WebGLUnpackedConv extends Conv {
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
const programManager = inferenceHandler.session.programManager;
Expand Down Expand Up @@ -242,7 +219,6 @@ export class WebGLUnpackedConv extends Conv {
let blend = false;
for (let k = 0; k < sharedDim; k += sharedDimReadSize) {
Logger.verbose('MatMul2D', `k = ${k}, sharedDim: ${sharedDim}, readSize = ${sharedDimReadSize}`);

if (k === sharedDimReadSize) {
blend = true;
gl.enable(gl.BLEND);
Expand Down Expand Up @@ -348,7 +324,7 @@ export class WebGLUnpackedConv extends Conv {
const outputLayout = inferenceHandler.createTextureLayoutFromShape(outputShape);
const initValue = (inputs.length < 3) ? '0.0' : '_B(b)';
const sharedDim = im2colLayout.shape[3];
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported;
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported && !this.activation;
const sharedDimReadSize = blendEnabled && inferenceHandler.session.backend.matmulMaxBatchSize ?
this.calcSharedDimReadSize(inferenceHandler.session.backend.matmulMaxBatchSize, sharedDim) :
sharedDim;
Expand Down
23 changes: 23 additions & 0 deletions lib/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import {Attribute} from '../attribute';
import {InferenceHandler} from '../backend';
import {glslRelu, glslSigmoid} from '../backends/webgl/ops/unary-op';
import {Operator} from '../operators';
import {Tensor} from '../tensor';

Expand Down Expand Up @@ -91,3 +92,25 @@ export abstract class Conv implements Operator {
protected strides: number[];
protected activation: string;
}

export function getActicationSnippet(activation: string) {
let activationFunction = '';
let activationName = '';
switch (activation) {
case 'Relu':
activationName = glslRelu().name;
activationFunction = glslRelu().body;
break;
case 'Sigmoid':
activationName = glslSigmoid().name;
activationFunction = glslSigmoid().body;
break;
default:
activationName = '';
activationFunction = '';
}
const applyActivation = activation ? `
value = ${activationName}(value);` :
'';
return {activationFunction, applyActivation};
}

0 comments on commit a2b4912

Please sign in to comment.