From fc8631e2f11d85c84ab9cc711aacb9c589b6f71a Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 28 Nov 2023 13:21:47 +0800 Subject: [PATCH] [js/web] Fix conv2dMatmul errors due to #18452 (#18562) ### Description Currently, all conv2dMatmul with inChannels = 3 and outChannels % 4 = 0 will report compilation errors. Models, which include this kind of shape will be impacted, like mobilenetv2-12, resnet50 . The errors is introduced by #18452 https://github.com/microsoft/onnxruntime/pull/18452/files#diff-8b24ea43aa11b1346c0c9e327f9bce6b37a93bd8f2bf8a6392b2b263972b7ea2R200, which accidentally pass `components` to `x`. But `x`'s components is `innerElementSize` not `components `. And when `innerElementSize` is 3, we should use `1` in current design. --- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 5 +-- js/web/test/data/ops/conv.jsonc | 32 ++++++++++++++++++- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 22f942a0d9ab4..3638938df7dbe 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -180,7 +180,7 @@ export const createConv2DMatMulProgramInfo = LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); - const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : elementsPerThread[0]; + const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; const tileAOuter = workGroupSize[1] * elementsPerThread[1]; const tileBOuter = workGroupSize[0] * elementsPerThread[0]; @@ -197,7 +197,8 @@ export const createConv2DMatMulProgramInfo = const components = isVec4 ? 4 : 1; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); + const x = + inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); const inputVariables = [x, w]; diff --git a/js/web/test/data/ops/conv.jsonc b/js/web/test/data/ops/conv.jsonc index 219e15eb4648f..2e8eaaba191d0 100644 --- a/js/web/test/data/ops/conv.jsonc +++ b/js/web/test/data/ops/conv.jsonc @@ -126,7 +126,7 @@ ] }, { - "name": "conv with bias addition C", + "name": "conv with bias addition C - NHWC", "operator": "Conv", "inputShapeDefinitions": "rankOnly", "opset": { "domain": "", "version": 17 }, @@ -158,6 +158,36 @@ "type": "float32" } ] + }, + { + "name": "inChannel = 3, outChannel = 4", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 10], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [ + 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 1, 2, 3, 4, 5, 6, 7, 8 + ], + "dims": [4, 3, 2, 2], + "type": "float32" + }, + { + "data": [5, 6, 7, 8], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [360, 334, 271, 323, 909, 963, 1024, 1028, 683, 655, 576, 650, 473, 508, 570, 677], + "dims": [1, 4, 2, 2], + "type": "float32" + } + ] } ] },