Patch for js/web/lib/wasm/jsep/webgpu/ops/resize.ts (shader-source generation for the WebGPU 'Resize' program).

What the patch changes:
  * bilinearInterpolation / trilinearInterpolation no longer take `scales`; the axis layout is selected by a
    local `isNchw = true` constant instead of the former `scales[1] === 1.0` test.
  * The extrapolation guard is emitted into the shader only when `useExtrapolation` is true, instead of
    emitting a boolean literal into the shader condition.
  * Linear weights are wrapped in abs(), and both weights are set to 0.5 when the two neighbouring indices
    coincide (the two sampled values are then identical, so the sum reproduces that value).
  * bicubic out-of-range handling is chosen while the shader string is built, so one branch is emitted.
  * Cubic mode throws for input ranks other than 2 and 4.
  * `inputShape` is appended to the shader cache hint.

Review corrections applied to the added lines (hunk line counts are unchanged):
  * trilinear extrapolation guard: the missing `if` keyword was added and the surplus `)` removed.
  * bilinear extrapolation guard: the surplus `)` was removed.
  * trilinear `width1 == width2` branch now assigns dz1/dz2 (it previously re-assigned dx1/dx2).
  * the two new error messages were rewritten to be grammatical and to name the correct mode.

NOTE(review): leading whitespace was re-created when this patch was re-flowed to one diff line per line;
context lines may differ from the target file in indentation only, so apply with whitespace tolerance
(for example `git apply --ignore-whitespace`). The blob ids on the `index` line predate the corrections.

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
index 50e27aa0da1c4..e0a8db9526370 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
@@ -321,10 +321,11 @@ const setChannelAndBatchIndices =
                                    '';
 
 const bilinearInterpolation =
-    (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[],
-     useExtrapolation: boolean, extrapolationValue: number): string => {
+    (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean,
+     extrapolationValue: number): string => {
+      const isNchw = true;  // layout is hard-coded to NCHW; replaces the removed scales[1] === 1.0 test
       const [batchIdx, heightIdx, widthIdx, channelIdx] =
-          inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]);
+          inputShape.length === 2 ? [-1, 0, 1, -1] : (isNchw ? [0, 2, 3, 1] : [0, 1, 2, 3]);
       const dType = input.type.value;
       return `
     fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} {
@@ -339,10 +340,12 @@ const bilinearInterpolation =
       var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
       var row:${dType} = originalIndices[${heightIdx}];
       var col:${dType} = originalIndices[${widthIdx}];
-      if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${
-    inputShape[widthIdx]} - 1)) {
+      ${
+          useExtrapolation ? `if (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > (${
+              inputShape[widthIdx]} - 1)) {
         return ${extrapolationValue};
-      }
+      }` :
+                             ''};
       row = max(0, min(row, ${inputShape[heightIdx]} - 1));
       col = max(0, min(col, ${inputShape[widthIdx]} - 1));
       var row1: u32 = u32(row);
@@ -355,10 +358,18 @@ const bilinearInterpolation =
       var x12: ${dType} = getInputValue(batch, channel, row1, col2);
       var x21: ${dType} = getInputValue(batch, channel, row2, col1);
       var x22: ${dType} = getInputValue(batch, channel, row2, col2);
-      var dx1: ${dType} = row - ${dType}(row1);
-      var dx2: ${dType} = ${dType}(row2) - row;
-      var dy1 = col - ${dType}(col1);
-      var dy2 = ${dType}(col2) - col;
+      var dx1: ${dType} = abs(row - ${dType}(row1));
+      var dx2: ${dType} = abs(${dType}(row2) - row);
+      var dy1: ${dType} = abs(col - ${dType}(col1));
+      var dy2: ${dType} = abs(${dType}(col2) - col);
+      if (row1 == row2) {
+        dx1 = 0.5;
+        dx2 = 0.5;
+      }
+      if (col1 == col2) {
+        dy1 = 0.5;
+        dy2 = 0.5;
+      }
       return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
     }`;
     };
@@ -367,7 +378,9 @@ const bicubicInterpolation =
     (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[],
      scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean,
      extrapolationValue: number, excludeOutside: boolean): string => {
-      const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2];
+      const is2D = inputShape.length === 2;
+      const isNchw = true;  // layout is hard-coded to NCHW; replaces the removed scales[1] === 1.0 test
+      const [heightIdx, widthIdx] = is2D ? [0, 1] : isNchw ? [2, 3] : [1, 2];
       const dType = input.type.value;
       const createCubicInterpolationFunction = (idx: number): string => {
         const direction = idx === heightIdx ? 'row' : 'col';
@@ -387,16 +400,18 @@ const bicubicInterpolation =
         for (var i: i32 = -1; i < 3; i++) {
           var ${direction}: ${dType} = originalIdx + ${dType}(i);
           if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) {
-            if (${excludeOutside}) {
-              coefs[i + 1] = 0.0;
-              continue;
-            } else if (${useExtrapolation}) {
-              return ${extrapolationValue};
-            } else {
-              ${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));
-            }
+            ${(() => {
+              if (excludeOutside) {  // decided while building the shader string: only one branch is emitted
+                return `coefs[i + 1] = 0.0;
+                        continue;`;
+              } else if (useExtrapolation) {
+                return `return ${extrapolationValue};`;
+              } else {
+                return `${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));`;
+              }
+            })()};
           }
-          var input_indices_copy: ${input.type.indices} = input_indices;
+        var input_indices_copy: ${input.type.indices} = input_indices;
           ${input.indicesSet('input_indices_copy', idx, `u32(${direction})`)};
           data[i + 1] = ${
               idx === heightIdx ? input.getByIndices('input_indices_copy') :
@@ -437,10 +452,11 @@ const bicubicInterpolation =
     };
 
 const trilinearInterpolation =
-    (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[],
-     useExtrapolation: boolean, extrapolationValue: number): string => {
+    (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean,
+     extrapolationValue: number): string => {
+      const isNchw = true;  // layout is hard-coded to NCHW; replaces the removed scales[1] === 1.0 test
       const [batchIdx, depthIdx, heightIdx, widthIdx, channelIdx] =
-          inputShape.length === 3 ? [-1, 0, 1, 2, -1] : (scales[1] === 1.0 ? [0, 2, 3, 4, 1] : [0, 1, 2, 3, 4]);
+          inputShape.length === 3 ? [-1, 0, 1, 2, -1] : (isNchw ? [0, 2, 3, 4, 1] : [0, 1, 2, 3, 4]);
       const dType = input.type.value;
       return `
     fn getInputValue(batch: u32, channel: u32, depth:u32, height: u32, width: u32) -> ${dType} {
@@ -457,11 +473,14 @@ const trilinearInterpolation =
       var depth:${dType} = originalIndices[${depthIdx}];
       var height:${dType} = originalIndices[${heightIdx}];
       var width:${dType} = originalIndices[${widthIdx}];
-      if (${useExtrapolation} && (depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${
-    inputShape[heightIdx]} - 1) || width < 0 || width > ${inputShape[widthIdx]} - 1)) {
-        return ${extrapolationValue};
-      }
-      depth = max(0, min(depth, ${inputShape[depthIdx]} - 1));
+      ${
+          useExtrapolation ? `if (depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${
+              inputShape[heightIdx]} - 1) || width < 0 || width > (${inputShape[widthIdx]} - 1)) {
+        return ${extrapolationValue};
+      }` :
+                             ''};
+
+      depth = max(0, min(depth, ${inputShape[depthIdx]} - 1));
       height = max(0, min(height, ${inputShape[heightIdx]} - 1));
       width = max(0, min(width, ${inputShape[widthIdx]} - 1));
       var depth1: u32 = u32(depth);
@@ -481,12 +500,24 @@ const trilinearInterpolation =
       var x212: ${dType} = getInputValue(batch, channel, depth2, height1, width2);
       var x221: ${dType} = getInputValue(batch, channel, depth2, height2, width1);
       var x222: ${dType} = getInputValue(batch, channel, depth2, height2, width2);
-      var dx1 = depth - ${dType}(depth1);
-      var dx2 = ${dType}(depth2) - depth;
-      var dy1 = height - ${dType}(height1);
-      var dy2 = ${dType}(height2) - height;
-      var dz1: ${dType} = width - ${dType}(width1);
-      var dz2: ${dType} = ${dType}(width2) - width;
+      var dx1: ${dType} = abs(depth - ${dType}(depth1));
+      var dx2: ${dType} = abs(${dType}(depth2) - depth);
+      var dy1: ${dType} = abs(height - ${dType}(height1));
+      var dy2: ${dType} = abs(${dType}(height2) - height);
+      var dz1: ${dType} = abs(width - ${dType}(width1));
+      var dz2: ${dType} = abs(${dType}(width2) - width);
+      if (depth1 == depth2) {
+        dx1 = 0.5;
+        dx2 = 0.5;
+      }
+      if (height1 == height2) {
+        dy1 = 0.5;
+        dy2 = 0.5;
+      }
+      if (width1 == width2) {
+        dz1 = 0.5;
+        dz2 = 0.5;
+      }
       return (x111 * dx2 * dy2 * dz2 + x112 * dx2 * dy2 * dz1 + x121 * dx2 * dy1 *dz2 + x122 * dx2 * dy1 * dz1 +
               x211 * dx1 * dy2 * dz2 + x212 * dx1 * dy2 * dz1 + x221 * dx1 * dy1 *dz2 + x222 * dx1 * dy1 * dz1);
     }`;
@@ -531,23 +562,28 @@ const createResizeProgramInfo =
             ${(() => {
               if (inputShape.length === 2 || inputShape.length === 4) {
                 return `${
-                    bilinearInterpolation(
-                        input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)}`;
+                    bilinearInterpolation(input, output, inputShape, useExtrapolation, attributes.extrapolationValue)}`;
               } else if (inputShape.length === 3 || inputShape.length === 5) {
                 return `${
                     trilinearInterpolation(
-                        input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)}`;
+                        input, output, inputShape, useExtrapolation, attributes.extrapolationValue)}`;
               } else {
-                throw Error('Only input dims 2, 3, 4 and 5 are supported in linear mode.');
+                throw Error('Linear mode only supports input dims 2, 3, 4 and 5.');
               }
             })()};
             `;
           case 'cubic':
             return `
-            ${
-                bicubicInterpolation(
-                    input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation,
-                    attributes.extrapolationValue, attributes.excludeOutside)};
+            ${(() => {
+              if (inputShape.length === 2 || inputShape.length === 4) {
+                return `${
+                    bicubicInterpolation(
+                        input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation,
+                        attributes.extrapolationValue, attributes.excludeOutside)}`;
+              } else {
+                throw Error('Cubic mode only supports input dims 2 and 4.');
+              }
+            })()};
             `;
           default:
             throw Error('Invalid resize mode');
@@ -575,8 +611,8 @@ const createResizeProgramInfo =
                   }`;
             case 'linear':
               return `output[global_idx] = ${
-                  inputShape.length === 2 || inputShape.length === 4 ? 'bilinearInterpolation' :
-                                                                       'trilinearInterpolation'}(output_indices);`;
+                  (inputShape.length === 2 || inputShape.length === 4) ? 'bilinearInterpolation' :
+                                                                         'trilinearInterpolation'}(output_indices);`;
             case 'cubic':
               return 'output[global_idx] = bicubicInterpolation(output_indices);';
             default:
@@ -590,7 +626,7 @@ const createResizeProgramInfo =
         name: 'Resize',
         shaderCache: {
           hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${
-              sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}`,
+              sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}|${inputShape}`,
           inputDependencies: ['rank']
         },
         getShaderSource,