Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
adding activation fuse
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhu1900 committed Apr 27, 2021
1 parent 290825d commit d0e2371
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
24 changes: 24 additions & 0 deletions lib/backends/webgl/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../t
import {WebGLContext} from '../webgl-context';
import {WebGLConvPacked} from './conv-pack';

import {glslRelu, glslSigmoid} from './unary-op';

export class WebGLConv extends Conv {
unpackedGroupedConvImpl: WebGLUnpackedGroupedConv;
unpackedConvImpl: WebGLUnpackedConv;
Expand Down Expand Up @@ -329,8 +331,27 @@ export class WebGLUnpackedConv extends Conv {
if (inputs.length === 3) {
samplers.push('B');
}

let activationFunction = '';
let activationName = '';
switch (this.activation) {
case 'Relu':
activationName = glslRelu().name;
activationFunction = glslRelu().body;
break;
case 'Sigmoid':
activationName = glslSigmoid().name;
activationFunction = glslSigmoid().body;
break;
default:
activationName = '';
activationFunction = '';
}
const applyActivation = this.activation.length === 0 ? '' : `sum = ${activationName}(sum);`;

const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
const shaderSource = `
${activationFunction}
float process(int indices[${rank}]) {
int b[1];
b[0] = indices[1];
Expand All @@ -349,6 +370,7 @@ export class WebGLUnpackedConv extends Conv {
++im2colOffset;
++kernelOffset;
}
${applyActivation}
return sum;
}`;
return {
Expand Down Expand Up @@ -400,4 +422,6 @@ export class WebGLUnpackedConv extends Conv {
protected artifacts: Artifact[];
protected readSize = 8;
protected blockSize = 64;

protected activation: string;
}
26 changes: 26 additions & 0 deletions lib/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ export declare namespace Graph {
export interface Transformer {
removeAllIdentityNodes(): void;
removeAllDropoutNodes(): void;

fuseConvActivationNodes(): void;
// TODO: add generic functions to manipulate the graph
}

Expand Down Expand Up @@ -559,6 +561,7 @@ class GraphImpl implements Graph, Graph.Transformer {
// apply common transform
this.removeAllIdentityNodes();
this.removeAllDropoutNodes();
this.fuseConvActivationNodes();

// apply initializer specific transform
if (graphInitializer) {
Expand Down Expand Up @@ -736,4 +739,27 @@ class GraphImpl implements Graph, Graph.Transformer {
nodeIndex++;
}
}

isActivation(n: Node): boolean {
switch (n.opType) {
// TODO: add other activation methods
case 'Relu':
case 'Sigmoid':
return true;
default:
return false;
}
}

fuseConvActivationNodes() {
for (const node of this._nodes) {
if (node.opType === 'Conv') {
const next = this._allData[node.outputs[0]]._to;
if (next.length === 1 && this.isActivation(this._nodes[next[0]])) {
node.attributes.set('__internal_activation', 'string', (this._nodes[next[0]].opType));
this.deleteNode(next[0]);
}
}
}
}
}

0 comments on commit d0e2371

Please sign in to comment.