Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Fix an error caused by glsl mod() function #294

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
7 changes: 3 additions & 4 deletions lib/backends/webgl/ops/im2col-pack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,11 @@ export class WebGLIm2ColPacked implements WebGLOperator {

if(blockIndex < ${im2colShape[1]} && pos < ${im2colShape[0]}) {
offsetY = int(blockIndex / (${this.convOutputShape[rank - 1]})) * ${this.strides[0]} - ${this.pads[1]};
d0 = offsetY + ${this.dilations[0]} * (int(mod(float(pos), ${kernelSize}.)) / ${wshape[2]} );
d0 = offsetY + ${this.dilations[0]} * (imod(pos, ${kernelSize}) / ${wshape[2]});

if(d0 < ${xshape[rowDim]} && d0 >= 0) {
offsetX = int(mod(float(blockIndex), ${this.convOutputShape[rank - 1]}.) * ${this.strides[1]}. - ${
this.pads[0]}.);
d1 = offsetX + ${this.dilations[1]} * (int(mod(mod(float(pos), ${kernelSize}.), ${wshape[2]}.)));
offsetX = imod(blockIndex, ${this.convOutputShape[rank - 1]}) * ${this.strides[1]} - ${this.pads[0]};
d1 = offsetX + ${this.dilations[1]} * imod(imod(pos, ${kernelSize}), ${wshape[2]});

if(d1 < ${xshape[colDim]} && d1 >= 0) {

Expand Down
46 changes: 35 additions & 11 deletions lib/backends/webgl/ops/reshape-packed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {getGlsl} from '../glsl-source';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, RunData, WebGLOperator} from '../types';
import {ProgramInfo, RunData, TextureData, WebGLOperator} from '../types';
import {TextureLayout} from '../types';

import {unpackFromChannel} from './packing_utils';
Expand All @@ -32,14 +32,18 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
// the same between input shape and output shape, the packed reshape can be
// treated as no-op.
const originInputShape = inputs[0].dims;
const inputShape3D = processDims3D(inputs[0].dims);
this.inputShape3D = processDims3D(inputs[0].dims);
let inputLayout: TextureLayout;
if (originInputShape.length === 3) {
inputLayout = handler.getOrCreateTextureLayout(inputs[0], 4, true, originInputShape, true);
} else {
inputLayout = handler.getOrCreateTextureLayout(inputs[0], 4, true, originInputShape, true);
if (originInputShape.length !== 3) {
const originalInputLayout = inputLayout;
// if originShape is not a 3D shape, create texture layout from the processed shape.
inputLayout =
handler.createTextureLayoutFromShape(inputShape3D, 4, inputShape3D, {isPacked: true, reverseWH: true});
inputLayout = handler.createTextureLayoutFromShape(
this.inputShape3D, 4, this.inputShape3D, {isPacked: true, reverseWH: true});
// if the processed input shape produces texture layout differnt from the original
// one, the run data has to use the processed (3D) input shape later.
this.needSqueezeInputData =
(inputLayout.height !== originalInputLayout.height) || (inputLayout.width !== originalInputLayout.width);
}

this.outputShape = ShapeUtil.calculateReshapedDims(originInputShape, inputs[1].integerData);
Expand Down Expand Up @@ -86,9 +90,10 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
const glsl = getGlsl(handler.session.backend.glContext.version);

const shaderSource = `
${getReshapedInputCoords(inputShape3D)}
${getReshapedInputCoords(this.inputShape3D)}
${getFlattenedIndexFrom3D(squeezedOutputShape)}
${unpackFromChannel()}

void main() {
ivec3 rc = getOutputCoords();

Expand All @@ -99,7 +104,6 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
int cols = ${squeezedOutputShape[1]};

${mainLoop}

${glsl.output} = result;
}
`;
Expand All @@ -115,8 +119,26 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
};
}
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
const inputTDs =
[handler.getOrCreateTextureData(inputs[0], handler.getOrCreateTextureLayout(inputs[0], 1, false, [], false))];
let inputTDs: [TextureData];
const originalInputLayout = handler.getOrCreateTextureLayout(inputs[0], 1, false, [], false);
const originalInputTD = handler.getOrCreateTextureData(inputs[0], originalInputLayout, false);

if (this.needSqueezeInputData) {
duli2012 marked this conversation as resolved.
Show resolved Hide resolved
const squeezedInputLayout: TextureLayout = {
channels: 1,
height: originalInputLayout.height,
width: originalInputLayout.width,
shape: this.inputShape3D,
strides: ShapeUtil.computeStrides(this.inputShape3D),
unpackedShape: this.inputShape3D,
};
const squeezedInputTD =
handler.createSharedTextureData(squeezedInputLayout, inputs[0].type, originalInputTD.texture);
inputTDs = [squeezedInputTD];

} else {
inputTDs = [originalInputTD];
}
let outputLayout = this.originalOutputLayout;
if (outputLayout === undefined) {
const originInputShape = inputs[0].dims;
Expand All @@ -133,6 +155,8 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
}
protected outputShape: ReadonlyArray<number>;
private originalOutputLayout: TextureLayout;
private inputShape3D: [number, number, number];
private needSqueezeInputData = false;
}

function processDims3D(shape: readonly number[]|ReadonlyArray<number>|Tensor.IntegerType): [number, number, number] {
Expand Down
5 changes: 5 additions & 0 deletions test/unittests/backends/webgl/test_reshape_packed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,10 @@ function getTestData(): TestData[] {
inputShape: [2, 2, 2, 4],
outputShape: [2, 1, 4, 4],
},
{
elementCount: 18432,
inputShape: [512, 36, 1, 1],
outputShape: [512, 36],
},
];
}