Skip to content

Commit

Permalink
[js/web] Fix conv2dMatmul errors due to microsoft#18452
Browse files Browse the repository at this point in the history
  • Loading branch information
qjia7 committed Nov 23, 2023
1 parent 64dacc2 commit 5e541ed
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ export const createConv2DMatMulProgramInfo =

LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`);

const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : elementsPerThread[0];
const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1;

const tileAOuter = workGroupSize[1] * elementsPerThread[1];
const tileBOuter = workGroupSize[0] * elementsPerThread[0];
Expand All @@ -197,7 +197,8 @@ export const createConv2DMatMulProgramInfo =
const components = isVec4 ? 4 : 1;
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components);
const x =
inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize);
const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
const inputVariables = [x, w];

Expand Down

0 comments on commit 5e541ed

Please sign in to comment.