diff --git a/lib/backends/webgl/ops/conv.ts b/lib/backends/webgl/ops/conv.ts index c73b908a..aebc7d7b 100644 --- a/lib/backends/webgl/ops/conv.ts +++ b/lib/backends/webgl/ops/conv.ts @@ -10,8 +10,8 @@ import {getGlsl} from '../glsl-source'; import {WebGLInferenceHandler} from '../inference-handler'; import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../types'; import {WebGLContext} from '../webgl-context'; -import {WebGLConvPacked} from './conv-pack'; +import {WebGLConvPacked} from './conv-pack'; import {glslRelu, glslSigmoid} from './unary-op'; export class WebGLConv extends Conv { @@ -347,7 +347,7 @@ export class WebGLUnpackedConv extends Conv { activationName = ''; activationFunction = ''; } - const applyActivation = this.activation.length === 0 ? '' : `sum = ${activationName}(sum);`; + const applyActivation = this.activation ? `sum = ${activationName}(sum);` : ''; const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); const shaderSource = ` @@ -422,6 +422,4 @@ export class WebGLUnpackedConv extends Conv { protected artifacts: Artifact[]; protected readSize = 8; protected blockSize = 64; - - protected activation: string; } diff --git a/lib/ops/conv.ts b/lib/ops/conv.ts index 666eafa5..f0ea235b 100644 --- a/lib/ops/conv.ts +++ b/lib/ops/conv.ts @@ -17,6 +17,7 @@ export abstract class Conv implements Operator { this.kernelShape = attributes.getInts('kernel_shape', []); this.pads = attributes.getInts('pads', [0, 0, 0, 0]); this.strides = attributes.getInts('strides', [1, 1]); + this.activation = attributes.getString('__internal_activation', ''); } checkInputs(inputs: Tensor[]): boolean { @@ -88,4 +89,5 @@ export abstract class Conv implements Operator { protected kernelShape: number[]; protected pads: number[]; protected strides: number[]; + protected activation: string; }