diff --git a/lib/backends/webgl/ops/conv.ts b/lib/backends/webgl/ops/conv.ts index e555c08a..e1e78b3e 100644 --- a/lib/backends/webgl/ops/conv.ts +++ b/lib/backends/webgl/ops/conv.ts @@ -3,7 +3,7 @@ import {Attribute} from '../../../attribute'; import {Logger} from '../../../instrument'; -import {Conv} from '../../../ops/conv'; +import {Conv, getActicationSnippet} from '../../../ops/conv'; import {Tensor} from '../../../tensor'; import {PoolConvUtil} from '../../../util'; import {getGlsl} from '../glsl-source'; @@ -12,7 +12,6 @@ import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../t import {WebGLContext} from '../webgl-context'; import {WebGLConvPacked} from './conv-pack'; -import {glslRelu, glslSigmoid} from './unary-op'; export class WebGLConv extends Conv { unpackedGroupedConvImpl: WebGLUnpackedGroupedConv; @@ -146,28 +145,6 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { } } -function getActicationSnippet(activation: string) { - let activationFunction = ''; - let activationName = ''; - switch (activation) { - case 'Relu': - activationName = glslRelu().name; - activationFunction = glslRelu().body; - break; - case 'Sigmoid': - activationName = glslSigmoid().name; - activationFunction = glslSigmoid().body; - break; - default: - activationName = ''; - activationFunction = ''; - } - const applyActivation = activation ? ` - value = ${activationName}(value);` : - ''; - return {activationFunction, applyActivation}; -} - export class WebGLUnpackedConv extends Conv { run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { const programManager = inferenceHandler.session.programManager; @@ -242,7 +219,6 @@ export class WebGLUnpackedConv extends Conv { let blend = false; for (let k = 0; k < sharedDim; k += sharedDimReadSize) { Logger.verbose('MatMul2D', `k = ${k}, sharedDim: ${sharedDim}, readSize = ${sharedDimReadSize}`); - if (k === sharedDimReadSize) { blend = true; gl.enable(gl.BLEND); @@ -348,7 +324,7 @@ export class WebGLUnpackedConv extends Conv { const outputLayout = inferenceHandler.createTextureLayoutFromShape(outputShape); const initValue = (inputs.length < 3) ? '0.0' : '_B(b)'; const sharedDim = im2colLayout.shape[3]; - const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported; + const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported && !this.activation; const sharedDimReadSize = blendEnabled && inferenceHandler.session.backend.matmulMaxBatchSize ? this.calcSharedDimReadSize(inferenceHandler.session.backend.matmulMaxBatchSize, sharedDim) : sharedDim; diff --git a/lib/ops/conv.ts b/lib/ops/conv.ts index f0ea235b..e8d78b7a 100644 --- a/lib/ops/conv.ts +++ b/lib/ops/conv.ts @@ -3,6 +3,7 @@ import {Attribute} from '../attribute'; import {InferenceHandler} from '../backend'; +import {glslRelu, glslSigmoid} from '../backends/webgl/ops/unary-op'; import {Operator} from '../operators'; import {Tensor} from '../tensor'; @@ -91,3 +92,25 @@ export abstract class Conv implements Operator { protected strides: number[]; protected activation: string; } + +export function getActicationSnippet(activation: string) { + let activationFunction = ''; + let activationName = ''; + switch (activation) { + case 'Relu': + activationName = glslRelu().name; + activationFunction = glslRelu().body; + break; + case 'Sigmoid': + activationName = glslSigmoid().name; + activationFunction = glslSigmoid().body; + break; + default: + activationName = ''; + activationFunction = ''; + } + const applyActivation = activation ? ` + value = ${activationName}(value);` : + ''; + return {activationFunction, applyActivation}; +}