diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 7b5b91faad380..50e27aa0da1c4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -312,6 +312,14 @@ const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]): return true; }`; +const setChannelAndBatchIndices = + (input: IndicesHelper, channelIdx: number, batchIdx: number, spacialDims: number): string => + input.rank > spacialDims ? ` + ${input.indicesSet('input_indices', channelIdx, 'channel')}; + ${input.indicesSet('input_indices', batchIdx, 'batch')}; +` : + ''; + const bilinearInterpolation = (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[], useExtrapolation: boolean, extrapolationValue: number): string => { @@ -323,10 +331,7 @@ const bilinearInterpolation = var input_indices: ${input.type.indices}; ${input.indicesSet('input_indices', heightIdx, `max(0, min(row, ${inputShape[heightIdx]} - 1))`)}; ${input.indicesSet('input_indices', widthIdx, `max(0, min(col, ${inputShape[widthIdx]} - 1))`)}; - if (${inputShape.length} > 2) { - ${input.indicesSet('input_indices', channelIdx, 'channel')}; - ${input.indicesSet('input_indices', batchIdx, 'batch')}; - }; + ${setChannelAndBatchIndices(input, channelIdx, batchIdx, 2)} return ${input.getByIndices('input_indices')}; } @@ -344,12 +349,8 @@ const bilinearInterpolation = var col1: u32 = u32(col); var row2: u32 = u32(row + 1); var col2: u32 = u32(col + 1); - var channel: u32 = 0; - var batch: u32 = 0; - if (${inputShape.length > 2}) { - channel = u32(originalIndices[${channelIdx}]); - batch = u32(originalIndices[${batchIdx}]); - } + var channel: u32 = ${inputShape.length > 2 ? `u32(originalIndices[${channelIdx}])` : '0'}; + var batch: u32 = ${inputShape.length > 2 ? `u32(originalIndices[${batchIdx}])` : '0'}; var x11: ${dType} = getInputValue(batch, channel, row1, col1); var x12: ${dType} = getInputValue(batch, channel, row1, col2); var x21: ${dType} = getInputValue(batch, channel, row2, col1); @@ -447,10 +448,7 @@ const trilinearInterpolation = ${input.indicesSet('input_indices', depthIdx, `max(0, min(depth, ${inputShape[depthIdx]} - 1))`)}; ${input.indicesSet('input_indices', heightIdx, `max(0, min(height, ${inputShape[heightIdx]} - 1))`)}; ${input.indicesSet('input_indices', widthIdx, `max(0, min(width, ${inputShape[widthIdx]} - 1))`)}; - if (${inputShape.length} > 3) { - ${input.indicesSet('input_indices', channelIdx, 'channel')}; - ${input.indicesSet('input_indices', batchIdx, 'batch')}; - }; + ${setChannelAndBatchIndices(input, channelIdx, batchIdx, 3)} return ${input.getByIndices('input_indices')}; } @@ -463,21 +461,18 @@ const trilinearInterpolation = inputShape[heightIdx]} - 1) || width < 0 || width > ${inputShape[widthIdx]} - 1)) { return ${extrapolationValue}; } - depth = max(0, min(row, ${inputShape[depthIdx]} - 1)); - height = max(0, min(row, ${inputShape[heightIdx]} - 1)); - width = max(0, min(col, ${inputShape[widthIdx]} - 1)); + depth = max(0, min(depth, ${inputShape[depthIdx]} - 1)); + height = max(0, min(height, ${inputShape[heightIdx]} - 1)); + width = max(0, min(width, ${inputShape[widthIdx]} - 1)); var depth1: u32 = u32(depth); var height1: u32 = u32(height); var width1: u32 = u32(width); var depth2: u32 = u32(depth + 1); var height2: u32 = u32(height + 1); var width2: u32 = u32(width + 1); - var channel: u32 = 0; - var batch: u32 = 0; - if (${inputShape.length > 2}) { - channel = u32(originalIndices[${channelIdx}]); - batch = u32(originalIndices[${batchIdx}]); - } + var channel: u32 = ${inputShape.length > 3 ? `u32(originalIndices[${channelIdx}])` : '0'}; + var batch: u32 = ${inputShape.length > 3 ? `u32(originalIndices[${batchIdx}])` : '0'}; + var x111: ${dType} = getInputValue(batch, channel, depth1, height1, width1); var x112: ${dType} = getInputValue(batch, channel, depth1, height1, width2); var x121: ${dType} = getInputValue(batch, channel, depth1, height2, width1); @@ -579,7 +574,9 @@ const createResizeProgramInfo = output[global_idx] = ${attributes.extrapolationValue}; }`; case 'linear': - return 'output[global_idx] = bilinearInterpolation(output_indices);'; + return `output[global_idx] = ${ + inputShape.length === 2 || inputShape.length === 4 ? 'bilinearInterpolation' : + 'trilinearInterpolation'}(output_indices);`; case 'cubic': return 'output[global_idx] = bicubicInterpolation(output_indices);'; default: