From 22e545825f272f484ff5ea2ee571a9a45a50ac13 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 29 Sep 2023 15:02:20 -0700 Subject: [PATCH] fix error caused by conflict resolve --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 3925e1cb4f564..f41d0d058a624 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -32,6 +32,7 @@ import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packe const conv2dTransposeCommonSnippet = (isChannelsLast: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, innerElementSize = 4): string => { + const type = typeSnippet(innerElementSize, 'f32'); const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: @@ -89,10 +90,10 @@ const conv2dTransposeCommonSnippet = let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { - return ${typeSnippet(innerElementSize)}(0.0); + return ${type}(0.0); } if (xC < 0.0 || xC >= f32(${xWidth}) || fract(xC) > 0.0) { - return ${typeSnippet(innerElementSize)}(0.0); + return ${type}(0.0); } let iXR = i32(xR); let iXC = i32(xC); @@ -105,13 +106,13 @@ const conv2dTransposeCommonSnippet = if (row < dimAOuter && col < dimInner) { ${readASnippet} } - return ${typeSnippet(innerElementSize)}(0.0);` : + return ${type}(0.0);` : ` let col = colIn * ${innerElementSize}; if (row < dimInner && col < dimBOuter) { ${readASnippet} } - return ${typeSnippet(innerElementSize)}(0.0);`; + return ${type}(0.0);`; const sampleW = ` let col = colIn * ${innerElementSize}; @@ -125,21 +126,21 @@ const conv2dTransposeCommonSnippet = let coord = vec4(coordX, coordY, col, rowInner); ${getWSnippet(innerElementSize)} } - return ${typeSnippet(innerElementSize)}(0.0); + return ${type}(0.0); `; const userCode = ` ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} - fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} { + fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleA : sampleW} } - fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} { + fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${type} { ${isChannelsLast ? sampleW : sampleA} } - fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${typeSnippet(innerElementSize)}) { + fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { let col = colIn * ${innerElementSize}; if (row < dimAOuter && col < dimBOuter) { var value = valueInput; @@ -234,10 +235,10 @@ export const createConv2DTransposeMatMulProgramInfo = ${declareFunctions} ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} ${ - isVec4 ? - makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : - makeMatMulPackedSource( - elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined, - sequentialAccessByThreads)}` + isVec4 ? makeMatMulPackedVec4Source( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, + undefined, sequentialAccessByThreads)}` }; };