Skip to content
This repository has been archived by the owner on Nov 16, 2023. It is now read-only.

Fix an error caused by glsl mod() function #294

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
7 changes: 3 additions & 4 deletions lib/backends/webgl/ops/im2col-pack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,11 @@ export class WebGLIm2ColPacked implements WebGLOperator {

if(blockIndex < ${im2colShape[1]} && pos < ${im2colShape[0]}) {
offsetY = int(blockIndex / (${this.convOutputShape[rank - 1]})) * ${this.strides[0]} - ${this.pads[1]};
d0 = offsetY + ${this.dilations[0]} * (int(mod(float(pos), ${kernelSize}.)) / ${wshape[2]} );
d0 = offsetY + ${this.dilations[0]} * (imod(pos, ${kernelSize}) / ${wshape[2]});

if(d0 < ${xshape[rowDim]} && d0 >= 0) {
offsetX = int(mod(float(blockIndex), ${this.convOutputShape[rank - 1]}.) * ${this.strides[1]}. - ${
this.pads[0]}.);
d1 = offsetX + ${this.dilations[1]} * (int(mod(mod(float(pos), ${kernelSize}.), ${wshape[2]}.)));
offsetX = imod(blockIndex, ${this.convOutputShape[rank - 1]}) * ${this.strides[1]} - ${this.pads[0]};
d1 = offsetX + ${this.dilations[1]} * imod(imod(pos, ${kernelSize}), ${wshape[2]});

if(d1 < ${xshape[colDim]} && d1 >= 0) {

Expand Down
5 changes: 5 additions & 0 deletions test/unittests/backends/webgl/test_reshape_packed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,10 @@ function getTestData(): TestData[] {
inputShape: [2, 2, 2, 4],
outputShape: [2, 1, 4, 4],
},
{
elementCount: 18432,
inputShape: [512, 36, 1, 1],
outputShape: [512, 36],
},
];
}