diff --git a/lib/backends/webgl/ops/conv-pack.ts b/lib/backends/webgl/ops/conv-pack.ts index 414bf64d..1976494b 100644 --- a/lib/backends/webgl/ops/conv-pack.ts +++ b/lib/backends/webgl/ops/conv-pack.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. +import {Attribute} from '../../../attribute'; import {Logger} from '../../../instrument'; import {Conv} from '../../../ops/conv'; import {Tensor} from '../../../tensor'; @@ -36,6 +37,11 @@ export class WebGLConvPacked extends Conv { const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides); const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides); const matmul = new WebGLMatMulPacked(); + if (!!this.activation) { + const attributes = new Attribute(undefined); + attributes.set('__internal_activation', 'string', (this.activation)); + matmul.initialize(attributes); + } const reshape = new WebGLReshapePacked(); // shape for kernel reshape const shape = diff --git a/lib/backends/webgl/ops/conv.ts b/lib/backends/webgl/ops/conv.ts index e1e78b3e..5600b7c9 100644 --- a/lib/backends/webgl/ops/conv.ts +++ b/lib/backends/webgl/ops/conv.ts @@ -3,7 +3,7 @@ import {Attribute} from '../../../attribute'; import {Logger} from '../../../instrument'; -import {Conv, getActicationSnippet} from '../../../ops/conv'; +import {Conv} from '../../../ops/conv'; import {Tensor} from '../../../tensor'; import {PoolConvUtil} from '../../../util'; import {getGlsl} from '../glsl-source'; @@ -12,6 +12,7 @@ import {Artifact, ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../t import {WebGLContext} from '../webgl-context'; import {WebGLConvPacked} from './conv-pack'; +import {getActicationSnippet} from './fuse_utils'; export class WebGLConv extends Conv { unpackedGroupedConvImpl: WebGLUnpackedGroupedConv; diff --git a/lib/backends/webgl/ops/fuse_utils.ts b/lib/backends/webgl/ops/fuse_utils.ts new file mode 100644 index 00000000..40eb6273 --- /dev/null +++ b/lib/backends/webgl/ops/fuse_utils.ts @@ -0,0 +1,23 @@ +import {glslRelu, glslSigmoid} from './unary-op'; + +export function getActicationSnippet(activation: string) { + let activationFunction = ''; + let activationName = ''; + switch (activation) { + case 'Relu': + activationName = glslRelu().name; + activationFunction = glslRelu().body; + break; + case 'Sigmoid': + activationName = glslSigmoid().name; + activationFunction = glslSigmoid().body; + break; + default: + activationName = ''; + activationFunction = ''; + } + const applyActivation = activation ? ` + value = ${activationName}(value);` : + ''; + return {activationFunction, applyActivation}; +} diff --git a/lib/backends/webgl/ops/matmul-pack.ts b/lib/backends/webgl/ops/matmul-pack.ts index 2f9e714e..280b2364 100644 --- a/lib/backends/webgl/ops/matmul-pack.ts +++ b/lib/backends/webgl/ops/matmul-pack.ts @@ -6,6 +6,7 @@ import {Tensor} from '../../../tensor'; import {BroadcastUtil} from '../../../util'; import {WebGLInferenceHandler} from '../inference-handler'; import {ProgramInfo, RunData, WebGLOperator} from '../types'; +import {getActicationSnippet} from './fuse_utils'; export class WebGLMatMulPacked extends MatMul implements WebGLOperator { run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { @@ -25,8 +26,11 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator { const aRank = aShape.length; const bRank = bShape.length; const sharedDim = aShape[aShape.length - 1]; + + const {activationFunction, applyActivation} = getActicationSnippet(this.activation); // TODO:fix broadcasting const shaderSource = ` + ${activationFunction} vec4 process(int indices[${rank}]) { int a[${aRank}]; int b[${bRank}]; @@ -41,6 +45,7 @@ export class WebGLMatMulPacked extends MatMul implements WebGLOperator { value += ${getA(aRank)}.ggaa * ${getB(bRank)}.baba; } ${processBias} + ${applyActivation} return value; }`; return { diff --git a/lib/ops/conv.ts b/lib/ops/conv.ts index e8d78b7a..f0ea235b 100644 --- a/lib/ops/conv.ts +++ b/lib/ops/conv.ts @@ -3,7 +3,6 @@ import {Attribute} from '../attribute'; import {InferenceHandler} from '../backend'; -import {glslRelu, glslSigmoid} from '../backends/webgl/ops/unary-op'; import {Operator} from '../operators'; import {Tensor} from '../tensor'; @@ -92,25 +91,3 @@ export abstract class Conv implements Operator { protected strides: number[]; protected activation: string; } - -export function getActicationSnippet(activation: string) { - let activationFunction = ''; - let activationName = ''; - switch (activation) { - case 'Relu': - activationName = glslRelu().name; - activationFunction = glslRelu().body; - break; - case 'Sigmoid': - activationName = glslSigmoid().name; - activationFunction = glslSigmoid().body; - break; - default: - activationName = ''; - activationFunction = ''; - } - const applyActivation = activation ? ` - value = ${activationName}(value);` : - ''; - return {activationFunction, applyActivation}; -} diff --git a/lib/ops/matmul.ts b/lib/ops/matmul.ts index eae672f9..91ab027a 100644 --- a/lib/ops/matmul.ts +++ b/lib/ops/matmul.ts @@ -9,7 +9,9 @@ import {Tensor} from '../tensor'; export abstract class MatMul implements Operator { abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise; - initialize(attributes: Attribute): void {} + initialize(attributes: Attribute): void { + this.activation = attributes.getString('__internal_activation', ''); + } checkInputs(inputs: Tensor[]): boolean { if (!inputs || inputs.length !== 2) { @@ -38,4 +40,5 @@ export abstract class MatMul implements Operator { return true; } + protected activation: string; } diff --git a/test/data/super_resolution/super-resolution-10.onnx b/test/data/super_resolution/super-resolution-10.onnx new file mode 100644 index 00000000..53283c05 Binary files /dev/null and b/test/data/super_resolution/super-resolution-10.onnx differ diff --git a/test/data/super_resolution/test_data_set_0/input_0.pb b/test/data/super_resolution/test_data_set_0/input_0.pb new file mode 100644 index 00000000..8d3724a2 Binary files /dev/null and b/test/data/super_resolution/test_data_set_0/input_0.pb differ diff --git a/test/data/teams_model/msra_190729.onnx b/test/data/teams_model/msra_190729.onnx new file mode 100644 index 00000000..2b5ccc97 Binary files /dev/null and b/test/data/teams_model/msra_190729.onnx differ diff --git a/test/data/teams_model/test_data_set_0/input_0.pb b/test/data/teams_model/test_data_set_0/input_0.pb new file mode 100644 index 00000000..b14ff347 Binary files /dev/null and b/test/data/teams_model/test_data_set_0/input_0.pb differ diff --git a/test/data/teams_model/test_data_set_0/output_0.pb b/test/data/teams_model/test_data_set_0/output_0.pb new file mode 100644 index 00000000..019ff65a Binary files /dev/null and b/test/data/teams_model/test_data_set_0/output_0.pb differ