From 943e2b5dcb0049a77c2ee265c560870249717a35 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 11 Dec 2024 14:55:30 -0800 Subject: [PATCH 1/2] [js/webgpu] fix Conv2DMatMul shader's out-of-bound read --- .../wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 3ef5c943d5624..dec2f9b6ac899 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -143,7 +143,15 @@ const conv2dCommonSnippet = ( } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`; - const sampleW = `${getWSnippet(innerElementSizeW)}`; + const sampleW = + fitInner && fitBOuter + ? getWSnippet(innerElementSizeW) + : ` + let col = colIn * ${innerElementSizeW}; + if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { + ${getWSnippet(innerElementSizeW)} + } + return ${typeSnippet(innerElementSizeW, dataType)}(0.0);`; const resType = typeSnippet(innerElementSize, dataType); const aType = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); From 22800f6d85667b21eae729c47a959d60ffdde4b1 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 11 Dec 2024 21:34:15 -0800 Subject: [PATCH 2/2] fix - consider channels last --- .../wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index dec2f9b6ac899..9e21a552b8466 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -143,14 +143,20 @@ const conv2dCommonSnippet = ( } return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`; - const sampleW = - fitInner && fitBOuter + const sampleW = isChannelsLast + ? fitInner && fitBOuter ? getWSnippet(innerElementSizeW) : ` let col = colIn * ${innerElementSizeW}; if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ${getWSnippet(innerElementSizeW)} } + return ${typeSnippet(innerElementSizeW, dataType)}(0.0);` + : ` + let col = colIn * ${innerElementSizeW}; + if (row < uniforms.dim_inner && col < uniforms.dim_a_outer) { + ${getWSnippet(innerElementSizeW)} + } return ${typeSnippet(innerElementSizeW, dataType)}(0.0);`; const resType = typeSnippet(innerElementSize, dataType);