Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Commit

Permalink
merge with conv
Browse files Browse the repository at this point in the history
  • Loading branch information
xzhu1900 committed Apr 28, 2021
1 parent f0292df commit 3aef5db
Showing 1 changed file with 38 additions and 28 deletions.
66 changes: 38 additions & 28 deletions lib/backends/webgl/ops/conv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {

createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
const hasBias = inputs.length > 2;
const processBias = hasBias ? `dotProd += getBias(output_channel);` : ``;
const processBias = hasBias ? `value += getBias(output_channel);` : ``;
const xShape = inputs[0].dims.slice();
const wShape = inputs[1].dims.slice();
const outputChannelsPerGroup = wShape[0] / this.group;
Expand All @@ -87,18 +87,20 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
const outputShape = WebGLConv.calcOutputShape(xShape, wShape, this.dilations, this.pads, this.strides);
const glsl = getGlsl(handler.session.backend.glContext.version);

const {activationFunction, applyActivation} = getActicationSnippet(this.activation);

const shaderSource = `
const ivec2 strides = ivec2(${this.strides[0]}, ${this.strides[1]});
const ivec2 pads = ivec2(${this.pads[0]}, ${this.pads[1]});
${activationFunction}
void main() {
ivec4 coords = getOutputCoords();
int batch = coords.x;
int output_channel = coords.y;
ivec2 xRCCorner = coords.zw * strides - pads;
int group_id = output_channel / ${outputChannelsPerGroup};
float dotProd = 0.0;
float value = 0.0;
for (int wInChannel = 0; wInChannel < ${wShape[1]}; wInChannel++) {
int input_channel = group_id * ${wShape[1]} + wInChannel;
for (int wHeight = 0; wHeight < ${wShape[2]}; wHeight++) {
Expand All @@ -116,12 +118,13 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
float xVal = getX(batch, input_channel, xWidth, xHeight);
float wVal = getW(output_channel, wInChannel, wWidth, wHeight);
dotProd += xVal*wVal;
value += xVal*wVal;
}
}
}
${processBias}
${glsl.output} = vec4(dotProd, .0, .0, .0);
${applyActivation}
${glsl.output} = vec4(value, .0, .0, .0);
}
`;
return {
Expand All @@ -143,6 +146,28 @@ export class WebGLUnpackedGroupedConv extends Conv implements WebGLOperator {
}
}

function getActicationSnippet(activation: string) {
let activationFunction = '';
let activationName = '';
switch (activation) {
case 'Relu':
activationName = glslRelu().name;
activationFunction = glslRelu().body;
break;
case 'Sigmoid':
activationName = glslSigmoid().name;
activationFunction = glslSigmoid().body;
break;
default:
activationName = '';
activationFunction = '';
}
const applyActivation = activation ? `
value = ${activationName}(value);` :
'';
return {activationFunction, applyActivation};
}

export class WebGLUnpackedConv extends Conv {
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
const programManager = inferenceHandler.session.programManager;
Expand Down Expand Up @@ -250,6 +275,7 @@ export class WebGLUnpackedConv extends Conv {
const im2colDims = WebGLUnpackedConv.calcIm2ColDims(xshape, kshape, outputShape, 4);
const outputLayout = inferenceHandler.createTextureLayoutFromShape(
im2colDims, 4, [im2colDims[0], im2colDims[1], im2colDims[2], im2colDims[3] * 4], {breakAxis: 3});

const shaderSource = `
const int XC = ${xshape[1]};
const int XH = ${xshape[2]};
Expand All @@ -265,13 +291,12 @@ export class WebGLUnpackedConv extends Conv {
const int KHKW = KH*KW;
const int XCKHKW = XC * KHKW;
const int outputChannels = 4;
vec4 process(int indices[${rank}]) {
int b = indices[0]; // batch size
int oh = indices[1] * strideH - padH; //output height
int ow = indices[2] * strideW - padW; //output width
int p = indices[3] * outputChannels; //patch
vec4 v = vec4(0.0);
vec4 value = vec4(0.0);
for(int i=0; i < outputChannels; ++i) {
if(p < XCKHKW) {
int patchC = p / KHKW;
Expand All @@ -288,12 +313,12 @@ export class WebGLUnpackedConv extends Conv {
xh2 < XH &&
xw2 >= 0 &&
xw2 < XW) {
v[i] = _X(x);
value[i] = _X(x);
}
}
++p;
}
return v;
return value;
}
`;
return {
Expand Down Expand Up @@ -332,22 +357,7 @@ export class WebGLUnpackedConv extends Conv {
samplers.push('B');
}

let activationFunction = '';
let activationName = '';
switch (this.activation) {
case 'Relu':
activationName = glslRelu().name;
activationFunction = glslRelu().body;
break;
case 'Sigmoid':
activationName = glslSigmoid().name;
activationFunction = glslSigmoid().body;
break;
default:
activationName = '';
activationFunction = '';
}
const applyActivation = this.activation ? `sum = ${activationName}(sum);` : '';
const {activationFunction, applyActivation} = getActicationSnippet(this.activation);

const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
const shaderSource = `
Expand All @@ -362,16 +372,16 @@ export class WebGLUnpackedConv extends Conv {
int im2colOffset = im2col[0] * ${im2colLayout.strides[0]} + im2col[1] * ${
im2colLayout.strides[1]} + im2col[2] * ${im2colLayout.strides[2]} + sharedDimOffset;
int kernelOffset = indices[1] * ${kLayout.strides[0]} + sharedDimOffset;
float sum = sharedDimOffset == 0 ? ${initValue} : 0.0;
float value = sharedDimOffset == 0 ? ${initValue} : 0.0;
for (int i = 0; i < ${sharedDimReadSize}; ++i) {
vec2 im2colCoords = offsetToCoords(im2colOffset, ${im2colLayout.width}, ${im2colLayout.height});
vec2 kernelCoords = offsetToCoords(kernelOffset, ${kLayout.width}, ${kLayout.height});
sum += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords));
value += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords));
++im2colOffset;
++kernelOffset;
}
${applyActivation}
return sum;
return value;
}`;
return {
inputLayouts: inputs.length === 3 ? [im2colLayout, kLayout, bLayout!] : [im2colLayout, kLayout],
Expand Down

0 comments on commit 3aef5db

Please sign in to comment.