Skip to content

Commit

Permalink
fix resize for fp16 (microsoft#19110)
Browse files Browse the repository at this point in the history
resize for fp16 has 2 issues: scales are always f32 and roi can be f32
or f16.
scales:
this is fixed.

roi
this is fixed for the case where roi is not passed as optional input
with f16. To fix this it requires a much larger change and I did not
want to risk this short before a release. For all practical purpose
passing roi as input with f16 should be rare and we can fix it in the
near future.
  • Loading branch information
guschmue authored Jan 12, 2024
1 parent ce08215 commit abbc3d9
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions web/lib/wasm/jsep/webgpu/ops/resize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ const validateInputs =
const rank = inputs[0].dims.length;
if (roiInputIndex > 0 && inputs.length > roiInputIndex && inputs[roiInputIndex].dims.length > 0) {
inputs[roiInputIndex].getFloat32Array().forEach((value) => roi.push(value));

} else if (attributes.coordinateTransformMode === 'tf_crop_and_resize') {
throw new Error('Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize');
}
Expand Down Expand Up @@ -110,20 +109,20 @@ const validateInputs =

const getOriginalCoordinateFromResizedCoordinate =
(coordinateTransferMode: CoordinateTransformMode, dType: string): string =>
`fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: ${dType}, lengthResized: u32,
lengthOriginal: u32, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` +
`fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: f32, lengthResized: u32,
lengthOriginal: u32, roiStart: f32, roiEnd: f32) -> ${dType} { ` +
(() => {
switch (coordinateTransferMode) {
case 'asymmetric':
return `return ${dType}(xResized) / xScale;`;
return `return ${dType}(xResized) / ${dType}(xScale);`;
case 'pytorch_half_pixel':
return `if (lengthResized > 1) {
return (${dType}(xResized) + 0.5) / xScale - 0.5;
return (${dType}(xResized) + 0.5) / ${dType}(xScale) - 0.5;
} else {
return 0.0;
}`;
case 'tf_half_pixel_for_nn':
return `return (${dType}(xResized) + 0.5) / xScale;`;
return `return (${dType}(xResized) + 0.5) / ${dType}(xScale);`;
case 'align_corners':
return `if (lengthResized == 1) {
return 0.0;
Expand All @@ -138,20 +137,20 @@ const getOriginalCoordinateFromResizedCoordinate =
}`;
case 'tf_crop_and_resize':
return `if (lengthResized > 1) {
return roiStart * ${dType}(lengthOriginal - 1) +
(${dType}(xResized) * (roiEnd - roiStart) * ${dType}(lengthOriginal - 1)) /
return ${dType}(roiStart) * ${dType}(lengthOriginal - 1) +
(${dType}(xResized) * ${dType}(roiEnd - roiStart) * ${dType}(lengthOriginal - 1)) /
${dType}(lengthResized - 1);
} else {
return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1);
return 0.5 * ${dType}(roiStart + roiEnd) * ${dType}(lengthOriginal - 1);
}`;
case 'half_pixel_symmetric':
return `const outputWidth = xScale * ${dType}(lengthResized);
return `const outputWidth = ${dType}xScale * ${dType}(lengthResized);
const adjustment = ${dType}(lengthResized) / outputWidth;
const center = ${dType}(lengthOriginal) / 2;
const offset = center * (1 - adjustment);
return offset + ((${dType}(xResized) + 0.5) / xScale) - 0.5;`;
return offset + ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`;
case 'half_pixel':
return `return ((${dType}(xResized) + 0.5) / xScale) - 0.5;`;
return `return ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`;
default:
throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`);
}
Expand Down Expand Up @@ -663,6 +662,10 @@ export const resize = (context: ComputeContext, attributes: ResizeAttributes): v
const scales: number[] = [];
const sizes: number[] = [];
const roi: number[] = [];

// Note that scales in resize are always f32. roi can be f32 or f16.
// TODO: Currently this code does not support f16 for roi when passed as optional input.

const opsetVersion = getOpsetVersionFromCustomDataBuffer(context);
if (attributes.antialias !== 0) {
throw Error('Only default value (0) for Antialias attribute is supported');
Expand Down

0 comments on commit abbc3d9

Please sign in to comment.