From 3aef5dbc3f5a112fe1798dc06dca338d8e5854fe Mon Sep 17 00:00:00 2001 From: xueyun zhu Date: Tue, 27 Apr 2021 16:26:06 -0700 Subject: [PATCH] merge with conv --- lib/backends/webgl/ops/conv.ts | 66 +++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/lib/backends/webgl/ops/conv.ts b/lib/backends/webgl/ops/conv.ts index aebc7d7b..e555c08a 100644 --- a/lib/backends/webgl/ops/conv.ts +++ b/lib/backends/webgl/ops/conv.ts @@ -68,7 +68,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo { const hasBias = inputs.length > 2; - const processBias = hasBias ? `dotProd += getBias(output_channel);` : ``; + const processBias = hasBias ? `value += getBias(output_channel);` : ``; const xShape = inputs[0].dims.slice(); const wShape = inputs[1].dims.slice(); const outputChannelsPerGroup = wShape[0] / this.group; @@ -87,10 +87,12 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { const outputShape = WebGLConv.calcOutputShape(xShape, wShape, this.dilations, this.pads, this.strides); const glsl = getGlsl(handler.session.backend.glContext.version); + const {activationFunction, applyActivation} = getActicationSnippet(this.activation); + const shaderSource = ` const ivec2 strides = ivec2(${this.strides[0]}, ${this.strides[1]}); const ivec2 pads = ivec2(${this.pads[0]}, ${this.pads[1]}); - + ${activationFunction} void main() { ivec4 coords = getOutputCoords(); int batch = coords.x; @@ -98,7 +100,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { ivec2 xRCCorner = coords.zw * strides - pads; int group_id = output_channel / ${outputChannelsPerGroup}; - float dotProd = 0.0; + float value = 0.0; for (int wInChannel = 0; wInChannel < ${wShape[1]}; wInChannel++) { int input_channel = group_id * ${wShape[1]} + wInChannel; for (int wHeight = 0; wHeight < ${wShape[2]}; wHeight++) { @@ -116,12 +118,13 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { float xVal = getX(batch, input_channel, xWidth, xHeight); float wVal = getW(output_channel, wInChannel, wWidth, wHeight); - dotProd += xVal*wVal; + value += xVal*wVal; } } } ${processBias} - ${glsl.output} = vec4(dotProd, .0, .0, .0); + ${applyActivation} + ${glsl.output} = vec4(value, .0, .0, .0); } `; return { @@ -143,6 +146,28 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator { } } +function getActicationSnippet(activation: string) { + let activationFunction = ''; + let activationName = ''; + switch (activation) { + case 'Relu': + activationName = glslRelu().name; + activationFunction = glslRelu().body; + break; + case 'Sigmoid': + activationName = glslSigmoid().name; + activationFunction = glslSigmoid().body; + break; + default: + activationName = ''; + activationFunction = ''; + } + const applyActivation = activation ? ` + value = ${activationName}(value);` : + ''; + return {activationFunction, applyActivation}; +} + export class WebGLUnpackedConv extends Conv { run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] { const programManager = inferenceHandler.session.programManager; @@ -250,6 +275,7 @@ export class WebGLUnpackedConv extends Conv { const im2colDims = WebGLUnpackedConv.calcIm2ColDims(xshape, kshape, outputShape, 4); const outputLayout = inferenceHandler.createTextureLayoutFromShape( im2colDims, 4, [im2colDims[0], im2colDims[1], im2colDims[2], im2colDims[3] * 4], {breakAxis: 3}); + const shaderSource = ` const int XC = ${xshape[1]}; const int XH = ${xshape[2]}; @@ -265,13 +291,12 @@ export class WebGLUnpackedConv extends Conv { const int KHKW = KH*KW; const int XCKHKW = XC * KHKW; const int outputChannels = 4; - vec4 process(int indices[${rank}]) { int b = indices[0]; // batch size int oh = indices[1] * strideH - padH; //output height int ow = indices[2] * strideW - padW; //output width int p = indices[3] * outputChannels; //patch - vec4 v = vec4(0.0); + vec4 value = vec4(0.0); for(int i=0; i < outputChannels; ++i) { if(p < XCKHKW) { int patchC = p / KHKW; @@ -288,12 +313,12 @@ export class WebGLUnpackedConv extends Conv { xh2 < XH && xw2 >= 0 && xw2 < XW) { - v[i] = _X(x); + value[i] = _X(x); } } ++p; } - return v; + return value; } `; return { @@ -332,22 +357,7 @@ export class WebGLUnpackedConv extends Conv { samplers.push('B'); } - let activationFunction = ''; - let activationName = ''; - switch (this.activation) { - case 'Relu': - activationName = glslRelu().name; - activationFunction = glslRelu().body; - break; - case 'Sigmoid': - activationName = glslSigmoid().name; - activationFunction = glslSigmoid().body; - break; - default: - activationName = ''; - activationFunction = ''; - } - const applyActivation = this.activation ? `sum = ${activationName}(sum);` : ''; + const {activationFunction, applyActivation} = getActicationSnippet(this.activation); const glsl = getGlsl(inferenceHandler.session.backend.glContext.version); const shaderSource = ` @@ -362,16 +372,16 @@ export class WebGLUnpackedConv extends Conv { int im2colOffset = im2col[0] * ${im2colLayout.strides[0]} + im2col[1] * ${ im2colLayout.strides[1]} + im2col[2] * ${im2colLayout.strides[2]} + sharedDimOffset; int kernelOffset = indices[1] * ${kLayout.strides[0]} + sharedDimOffset; - float sum = sharedDimOffset == 0 ? ${initValue} : 0.0; + float value = sharedDimOffset == 0 ? ${initValue} : 0.0; for (int i = 0; i < ${sharedDimReadSize}; ++i) { vec2 im2colCoords = offsetToCoords(im2colOffset, ${im2colLayout.width}, ${im2colLayout.height}); vec2 kernelCoords = offsetToCoords(kernelOffset, ${kLayout.width}, ${kLayout.height}); - sum += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords)); + value += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords)); ++im2colOffset; ++kernelOffset; } ${applyActivation} - return sum; + return value; }`; return { inputLayouts: inputs.length === 3 ? [im2colLayout, kLayout, bLayout!] : [im2colLayout, kLayout],