Skip to content

Commit

Permalink
[js/webgpu] Mitigate floating point accuracy issue in Resize (#18956)
Browse files Browse the repository at this point in the history
### Description
The patch fixes a floating point accuracy issue in Resize by preferring
integer indices and integer arithmetic where possible.

### Motivation and Context
Model test `test_resize_upsample_sizes_nearest_floor_align_corners` was
observed to be failing on certain platforms. The root cause is the
inaccurate floating point evaluation of 21 / 7 (2.999... vs 3), which
results in the wrong input element to be indexed (floor(2.999...) vs
floor(3)).
  • Loading branch information
hujiajie authored Jan 3, 2024
1 parent c5f3952 commit 3b8b914
Showing 1 changed file with 45 additions and 38 deletions.
83 changes: 45 additions & 38 deletions js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,41 +110,48 @@ const validateInputs =

const getOriginalCoordinateFromResizedCoordinate =
(coordinateTransferMode: CoordinateTransformMode, dType: string): string =>
`fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType},
lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` +
`fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: ${dType}, lengthResized: u32,
lengthOriginal: u32, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` +
(() => {
switch (coordinateTransferMode) {
case 'asymmetric':
return 'return xResized / xScale;';
return `return ${dType}(xResized) / xScale;`;
case 'pytorch_half_pixel':
return 'if (lengthResized > 1) { \
return (xResized + 0.5) / xScale - 0.5; \
} else { \
return 0.0; \
}';
return `if (lengthResized > 1) {
return (${dType}(xResized) + 0.5) / xScale - 0.5;
} else {
return 0.0;
}`;
case 'tf_half_pixel_for_nn':
return 'return (xResized + 0.5) / xScale;';
return `return (${dType}(xResized) + 0.5) / xScale;`;
case 'align_corners':
return 'if (lengthResized == 1) { \
return 0.0; \
} else { \
return xResized * (lengthOriginal - 1) / (lengthResized - 1); \
}';
return `if (lengthResized == 1) {
return 0.0;
} else {
// The whole part and the fractional part are calculated separately due to inaccuracy of floating
// point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an
// offset-by-one error later in floor().
let whole = ${dType}(xResized * (lengthOriginal - 1) / (lengthResized - 1));
let fract =
${dType}(xResized * (lengthOriginal - 1) % (lengthResized - 1)) / ${dType}(lengthResized - 1);
return whole + fract;
}`;
case 'tf_crop_and_resize':
return `if (lengthResized > 1) { \
return roiStart * (lengthOriginal - 1) + \
(xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \
} else { \
return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \
return `if (lengthResized > 1) {
return roiStart * ${dType}(lengthOriginal - 1) +
(${dType}(xResized) * (roiEnd - roiStart) * ${dType}(lengthOriginal - 1)) /
${dType}(lengthResized - 1);
} else {
return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1);
}`;
case 'half_pixel_symmetric':
return [
'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;',
'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);',
'return offset + ((xResized + 0.5) / xScale) - 0.5;'
].join('\n');
return `const outputWidth = xScale * ${dType}(lengthResized);
const adjustment = ${dType}(lengthResized) / outputWidth;
const center = ${dType}(lengthOriginal) / 2;
const offset = center * (1 - adjustment);
return offset + ((${dType}(xResized) + 0.5) / xScale) - 0.5;`;
case 'half_pixel':
return 'return ((xResized + 0.5) / xScale) - 0.5;';
return `return ((${dType}(xResized) + 0.5) / xScale) - 0.5;`;
default:
throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`);
}
Expand Down Expand Up @@ -254,15 +261,15 @@ const calculateOriginalIndicesFromOutputIndices =
output.type.value}, ${outputShape.length}> {
var original_indices: array<${output.type.value}, ${outputShape.length}>;
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')});
var output_index = ${output.indicesGet('output_indices', 'i')};
var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)};
var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)};
var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)};
if (scale == 1.0) {
original_indices[i] = output_index;
original_indices[i] = ${output.type.value}(output_index);
} else {
var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)});
var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)});
var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)};
original_indices[i] = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
input_shape_i, roi_low, roi_hi);
}
Expand All @@ -276,23 +283,23 @@ const calculateInputIndicesFromOutputIndices =
fn calculateInputIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} {
var input_indices: ${input.type.indices};
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')});
var output_index = ${output.indicesGet('output_indices', 'i')};
var input_index: u32;
var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)};
if (scale == 1.0) {
input_index = u32(output_index);
input_index = output_index;
} else {
var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)};
var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)};
var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)});
var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)});
var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)};
var original_idx = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
input_shape_i, roi_low, roi_hi);
if (!${useExtrapolation} || (original_idx >= 0 && original_idx < input_shape_i)) {
if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${output.type.value}(input_shape_i))) {
if (original_idx < 0) {
input_index = 0;
} else if (original_idx > (input_shape_i - 1)) {
input_index = u32(input_shape_i) - 1;
} else if (original_idx > ${output.type.value}(input_shape_i - 1)) {
input_index = input_shape_i - 1;
} else {
input_index = u32(getNearestPixelFromOriginal(original_idx, scale < 1));
}
Expand Down Expand Up @@ -391,8 +398,8 @@ const bicubicInterpolation =
fn ${direction}CubicInterpolation(input_indices: ${input.type.indices}, output_indices: ${
output.type.indices}) -> ${dType} {
var output_index = ${output.indicesGet('output_indices', idx)};
var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(output_index), ${scales[idx]},
${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(output_index, ${scales[idx]},
${outputShape[idx]}, ${inputShape[idx]}, ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx);
var coefs = getCubicInterpolationCoefs(fractOriginalIdx);
Expand Down

0 comments on commit 3b8b914

Please sign in to comment.