From 3b8b9147fa4f8f6348e171a257bbc325744301df Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Thu, 4 Jan 2024 06:15:26 +0800 Subject: [PATCH] [js/webgpu] Mitigate floating point accuracy issue in Resize (#18956) ### Description The patch fixes a floating point accuracy issue in Resize by preferring integer indices and integer arithmetic where possible. ### Motivation and Context Model test `test_resize_upsample_sizes_nearest_floor_align_corners` was observed to be failing on certain platforms. The root cause is the inaccurate floating point evaluation of 21 / 7 (2.999... vs 3), which results in the wrong input element to be indexed (floor(2.999...) vs floor(3)). --- js/web/lib/wasm/jsep/webgpu/ops/resize.ts | 83 ++++++++++++----------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index bea3e8625b41b..d359580904a7b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -110,41 +110,48 @@ const validateInputs = const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode, dType: string): string => - `fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType}, - lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` + + `fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: ${dType}, lengthResized: u32, + lengthOriginal: u32, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` + (() => { switch (coordinateTransferMode) { case 'asymmetric': - return 'return xResized / xScale;'; + return `return ${dType}(xResized) / xScale;`; case 'pytorch_half_pixel': - return 'if (lengthResized > 1) { \ - return (xResized + 0.5) / xScale - 0.5; \ - } else { \ - return 0.0; \ - }'; + return `if (lengthResized > 1) { + return (${dType}(xResized) + 0.5) / xScale - 0.5; + } else { + return 0.0; + }`; case 'tf_half_pixel_for_nn': - return 'return (xResized + 0.5) / xScale;'; + return `return (${dType}(xResized) + 0.5) / xScale;`; case 'align_corners': - return 'if (lengthResized == 1) { \ - return 0.0; \ - } else { \ - return xResized * (lengthOriginal - 1) / (lengthResized - 1); \ - }'; + return `if (lengthResized == 1) { + return 0.0; + } else { + // The whole part and the fractional part are calculated separately due to inaccuracy of floating + // point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an + // offset-by-one error later in floor(). + let whole = ${dType}(xResized * (lengthOriginal - 1) / (lengthResized - 1)); + let fract = + ${dType}(xResized * (lengthOriginal - 1) % (lengthResized - 1)) / ${dType}(lengthResized - 1); + return whole + fract; + }`; case 'tf_crop_and_resize': - return `if (lengthResized > 1) { \ - return roiStart * (lengthOriginal - 1) + \ - (xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \ - } else { \ - return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \ + return `if (lengthResized > 1) { + return roiStart * ${dType}(lengthOriginal - 1) + + (${dType}(xResized) * (roiEnd - roiStart) * ${dType}(lengthOriginal - 1)) / + ${dType}(lengthResized - 1); + } else { + return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); }`; case 'half_pixel_symmetric': - return [ - 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;', - 'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);', - 'return offset + ((xResized + 0.5) / xScale) - 0.5;' - ].join('\n'); + return `const outputWidth = xScale * ${dType}(lengthResized); + const adjustment = ${dType}(lengthResized) / outputWidth; + const center = ${dType}(lengthOriginal) / 2; + const offset = center * (1 - adjustment); + return offset + ((${dType}(xResized) + 0.5) / xScale) - 0.5;`; case 'half_pixel': - return 'return ((xResized + 0.5) / xScale) - 0.5;'; + return `return ((${dType}(xResized) + 0.5) / xScale) - 0.5;`; default: throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); } @@ -254,15 +261,15 @@ const calculateOriginalIndicesFromOutputIndices = output.type.value}, ${outputShape.length}> { var original_indices: array<${output.type.value}, ${outputShape.length}>; for (var i:u32 = 0; i < ${outputShape.length}; i++) { - var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')}); + var output_index = ${output.indicesGet('output_indices', 'i')}; var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)}; var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)}; var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)}; if (scale == 1.0) { - original_indices[i] = output_index; + original_indices[i] = ${output.type.value}(output_index); } else { - var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)}); - var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)}); + var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}; + var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)}; original_indices[i] = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i, input_shape_i, roi_low, roi_hi); } @@ -276,23 +283,23 @@ const calculateInputIndicesFromOutputIndices = fn calculateInputIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} { var input_indices: ${input.type.indices}; for (var i:u32 = 0; i < ${outputShape.length}; i++) { - var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')}); + var output_index = ${output.indicesGet('output_indices', 'i')}; var input_index: u32; var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)}; if (scale == 1.0) { - input_index = u32(output_index); + input_index = output_index; } else { var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)}; var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)}; - var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)}); - var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)}); + var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}; + var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)}; var original_idx = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i, input_shape_i, roi_low, roi_hi); - if (!${useExtrapolation} || (original_idx >= 0 && original_idx < input_shape_i)) { + if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${output.type.value}(input_shape_i))) { if (original_idx < 0) { input_index = 0; - } else if (original_idx > (input_shape_i - 1)) { - input_index = u32(input_shape_i) - 1; + } else if (original_idx > ${output.type.value}(input_shape_i - 1)) { + input_index = input_shape_i - 1; } else { input_index = u32(getNearestPixelFromOriginal(original_idx, scale < 1)); } @@ -391,8 +398,8 @@ const bicubicInterpolation = fn ${direction}CubicInterpolation(input_indices: ${input.type.indices}, output_indices: ${ output.type.indices}) -> ${dType} { var output_index = ${output.indicesGet('output_indices', idx)}; - var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(output_index), ${scales[idx]}, - ${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); + var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(output_index, ${scales[idx]}, + ${outputShape[idx]}, ${inputShape[idx]}, ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx); var coefs = getCubicInterpolationCoefs(fractOriginalIdx);