From a9a4677c35d552579073f2ac0c482dc0b4ec94c4 Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Tue, 7 Nov 2023 16:07:29 +0800 Subject: [PATCH] [js/webgpu] Simplify the Resize shader when noScale is true --- js/web/lib/wasm/jsep/webgpu/ops/resize.ts | 26 +++++++++++------------ 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index fed1dbcf51e9b..07cfefb8f191b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -454,6 +454,7 @@ const createResizeProgramInfo = const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]); const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize'; const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${noScale ? '' : ` ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)}; ${(() => { switch (attributes.mode) { @@ -483,23 +484,22 @@ const createResizeProgramInfo = throw Error('Invalid resize mode'); } })()}; + `} ${shaderHelper.declareVariables(input, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - if (${noScale}) { - output[global_idx] = input[global_idx]; - } else { - let outputIndices = ${output.offsetToIndices('global_idx')}; - var inputIndices: ${input.type.indices}; - ${(() => { + ${noScale ? 'output[global_idx] = input[global_idx];' : ` + let outputIndices = ${output.offsetToIndices('global_idx')}; + var inputIndices: ${input.type.indices}; + ${(() => { switch (attributes.mode) { case 'nearest': return `inputIndices = calculateInputIndicesFromOutputIndices(outputIndices); - if (checkInputIndices(inputIndices)) { - output[global_idx] = input[${input.indicesToOffset('inputIndices')}]; - } else { - output[global_idx] = ${attributes.extrapolationValue}; - }`; + if (checkInputIndices(inputIndices)) { + output[global_idx] = input[${input.indicesToOffset('inputIndices')}]; + } else { + output[global_idx] = ${attributes.extrapolationValue}; + }`; case 'linear': return 'output[global_idx] = bilinearInterpolation(outputIndices);'; case 'cubic': @@ -508,14 +508,14 @@ const createResizeProgramInfo = throw Error(`Unsupported resize mode: ${attributes.mode}`); } })()}; - } + `} }`; return { name: 'Resize', shaderCache: { hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${ - sizes.length > 0 ? sizes : ''}` + sizes.length > 0 ? sizes : ''}|${noScale}` }, getShaderSource, getRunData: () => ({