From abbc3d9e327c40d8cedbe2c54f68821de6e67639 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 12 Jan 2024 13:44:28 -0800 Subject: [PATCH] fix resize for fp16 (#19110) resize for fp16 has 2 issues: scales are always f32 and roi can be f32 or f16. scales: this is fixed. roi this is fixed for the case where roi is not passed as optional input with f16. To fix this it requires a much larger change and I did not want to risk this short before a release. For all practical purpose passing roi as input with f16 should be rare and we can fix it in the near future. --- web/lib/wasm/jsep/webgpu/ops/resize.ts | 27 ++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/web/lib/wasm/jsep/webgpu/ops/resize.ts b/web/lib/wasm/jsep/webgpu/ops/resize.ts index d359580904a7b..f68526acc0e63 100644 --- a/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -70,7 +70,6 @@ const validateInputs = const rank = inputs[0].dims.length; if (roiInputIndex > 0 && inputs.length > roiInputIndex && inputs[roiInputIndex].dims.length > 0) { inputs[roiInputIndex].getFloat32Array().forEach((value) => roi.push(value)); - } else if (attributes.coordinateTransformMode === 'tf_crop_and_resize') { throw new Error('Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize'); } @@ -110,20 +109,20 @@ const validateInputs = const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode, dType: string): string => - `fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: ${dType}, lengthResized: u32, - lengthOriginal: u32, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` + + `fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: f32, lengthResized: u32, + lengthOriginal: u32, roiStart: f32, roiEnd: f32) -> ${dType} { ` + (() => { switch (coordinateTransferMode) { case 'asymmetric': - return `return ${dType}(xResized) / xScale;`; + return `return ${dType}(xResized) / ${dType}(xScale);`; case 'pytorch_half_pixel': return `if (lengthResized > 1) { - return (${dType}(xResized) + 0.5) / xScale - 0.5; + return (${dType}(xResized) + 0.5) / ${dType}(xScale) - 0.5; } else { return 0.0; }`; case 'tf_half_pixel_for_nn': - return `return (${dType}(xResized) + 0.5) / xScale;`; + return `return (${dType}(xResized) + 0.5) / ${dType}(xScale);`; case 'align_corners': return `if (lengthResized == 1) { return 0.0; @@ -138,20 +137,20 @@ const getOriginalCoordinateFromResizedCoordinate = }`; case 'tf_crop_and_resize': return `if (lengthResized > 1) { - return roiStart * ${dType}(lengthOriginal - 1) + - (${dType}(xResized) * (roiEnd - roiStart) * ${dType}(lengthOriginal - 1)) / + return ${dType}(roiStart) * ${dType}(lengthOriginal - 1) + + (${dType}(xResized) * ${dType}(roiEnd - roiStart) * ${dType}(lengthOriginal - 1)) / ${dType}(lengthResized - 1); } else { - return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); + return 0.5 * ${dType}(roiStart + roiEnd) * ${dType}(lengthOriginal - 1); }`; case 'half_pixel_symmetric': - return `const outputWidth = xScale * ${dType}(lengthResized); + return `const outputWidth = ${dType}xScale * ${dType}(lengthResized); const adjustment = ${dType}(lengthResized) / outputWidth; const center = ${dType}(lengthOriginal) / 2; const offset = center * (1 - adjustment); - return offset + ((${dType}(xResized) + 0.5) / xScale) - 0.5;`; + return offset + ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`; case 'half_pixel': - return `return ((${dType}(xResized) + 0.5) / xScale) - 0.5;`; + return `return ((${dType}(xResized) + 0.5) / ${dType}(xScale)) - 0.5;`; default: throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); } @@ -663,6 +662,10 @@ export const resize = (context: ComputeContext, attributes: ResizeAttributes): v const scales: number[] = []; const sizes: number[] = []; const roi: number[] = []; + + // Note that scales in resize are always f32. roi can be f32 or f16. + // TODO: Currently this code does not support f16 for roi when passed as optional input. + const opsetVersion = getOpsetVersionFromCustomDataBuffer(context); if (attributes.antialias !== 0) { throw Error('Only default value (0) for Antialias attribute is supported');