From e33f7914c69d76920a66fd771c0198e3b2dd9b72 Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Thu, 23 Nov 2023 00:12:07 +0400 Subject: [PATCH] [JS/Web] Resize & BiasSplitGelu fp16 support (#18536) ### Description Resize and BiasSplitGelu fp16 support on WebGPU --- .../wasm/jsep/webgpu/ops/bias-split-gelu.ts | 5 +- web/lib/wasm/jsep/webgpu/ops/resize.ts | 151 +++++++++--------- 2 files changed, 81 insertions(+), 75 deletions(-) diff --git a/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts index 14eefc344f3c0..a81a7a8f1df5c 100644 --- a/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts +++ b/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts @@ -5,7 +5,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo} from '../types'; -import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; import {erfImpl} from './unary-op'; const validateInputs = (inputs: readonly TensorView[]): void => { @@ -35,6 +35,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI const output = outputVariable('output', inputs[0].dataType, outputShape, 4); const outputSize = ShapeUtil.size(outputShape) / 4; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const getShaderSource = (shaderHelper: ShaderHelper) => ` const M_SQRT2 = sqrt(2.0); @@ -42,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI ${shaderHelper.declareVariables(input, bias, output)} - ${erfImpl('vec4f')} + ${erfImpl(`vec4<${dataType}>`, dataType)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/web/lib/wasm/jsep/webgpu/ops/resize.ts b/web/lib/wasm/jsep/webgpu/ops/resize.ts index 9869561a36251..973a607f9377e 100644 --- a/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -105,50 +105,51 @@ const validateInputs = } }; -const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode): string => - 'fn getOriginalCoordinateFromResizedCoordinate(xResized: f32, xScale: f32, lengthResized: f32,\ - lengthOriginal: f32, roiStart: f32, roiEnd: f32) -> f32 { ' + +const getOriginalCoordinateFromResizedCoordinate = + (coordinateTransferMode: CoordinateTransformMode, dType: string): string => + `fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType}, + lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` + (() => { - switch (coordinateTransferMode) { - case 'asymmetric': - return 'return xResized / xScale;'; - case 'pytorch_half_pixel': - return 'if (lengthResized > 1) { \ + switch (coordinateTransferMode) { + case 'asymmetric': + return 'return xResized / xScale;'; + case 'pytorch_half_pixel': + return 'if (lengthResized > 1) { \ return (xResized + 0.5) / xScale - 0.5; \ } else { \ return 0.0; \ }'; - case 'tf_half_pixel_for_nn': - return 'return (xResized + 0.5) / xScale;'; - case 'align_corners': - return 'if (lengthResized == 1) { \ + case 'tf_half_pixel_for_nn': + return 'return (xResized + 0.5) / xScale;'; + case 'align_corners': + return 'if (lengthResized == 1) { \ return 0.0; \ } else { \ return xResized * (lengthOriginal - 1) / (lengthResized - 1); \ }'; - case 'tf_crop_and_resize': - return 'if (lengthResized > 1) { \ + case 'tf_crop_and_resize': + return `if (lengthResized > 1) { \ return roiStart * (lengthOriginal - 1) + \ (xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \ } else { \ - return 0.5 * (roiStart + roiEnd) * f32(lengthOriginal - 1); \ - }'; - case 'half_pixel_symmetric': - return [ - 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;', - 'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);', - 'return offset + ((xResized + 0.5) / xScale) - 0.5;' - ].join('\n'); - case 'half_pixel': - return 'return ((xResized + 0.5) / xScale) - 0.5;'; - default: - throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); - } - })() + + return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \ + }`; + case 'half_pixel_symmetric': + return [ + 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;', + 'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);', + 'return offset + ((xResized + 0.5) / xScale) - 0.5;' + ].join('\n'); + case 'half_pixel': + return 'return ((xResized + 0.5) / xScale) - 0.5;'; + default: + throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`); + } + })() + '}'; -const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number): string => - 'fn getNearestPixelFromOriginal(xOriginal: f32, isDownSample: bool) -> f32 {' + (() => { +const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number, dType: string): string => + `fn getNearestPixelFromOriginal(xOriginal: ${dType}, isDownSample: bool) -> ${dType} {` + (() => { switch (nearestMode) { case 'round_prefer_ceil': return 'if (fract(xOriginal) == 0.5) { \ @@ -246,20 +247,21 @@ const adjustOutputShape = (inputShape: readonly number[], scales: number[], attr const calculateOriginalIndicesFromOutputIndices = (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], roi: readonly number[]): string => ` - fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array { + fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array<${ + output.type.value}, ${outputShape.length}> { const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); const outputShape = array(${outputShape.map(i => `${i}u`).join(',')}); - const scales = array(${scales.map(i => `${i}f`).join(',')}); - const roi = array(${roi.map(i => `${i}f`).join(',')}); - var originalIndices: array; + const scales = array<${output.type.value}, ${scales.length}>(${scales.map(i => `${i}f`).join(',')}); + const roi = array<${output.type.value}, ${roi.length}>(${roi.map(i => `${i}f`).join(',')}); + var originalIndices: array<${output.type.value}, ${outputShape.length}>; for (var i:u32 = 0; i < ${outputShape.length}; i++) { var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; if (scales[i] == 1.0) { - originalIndices[i] = f32(outputIndex); + originalIndices[i] = ${output.type.value}(outputIndex); } else { - originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i], - f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]); + originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(${output.type.value}(outputIndex), scales[i], + ${output.type.value}(outputShape[i]), ${output.type.value}(inputShape[i]), roi[i], roi[i + ${ + inputShape.length}]); } } return originalIndices; @@ -271,8 +273,8 @@ const calculateInputIndicesFromOutputIndices = fn calculateInputIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} { const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); const outputShape = array(${outputShape.map(i => `${i}u`).join(',')}); - const scales = array(${scales.map(i => `${i}f`).join(',')}); - const roi = array(${roi.map(i => `${i}f`).join(',')}); + const scales = array<${input.type.value}, ${scales.length}>(${scales.map(i => `${i}`).join(',')}); + const roi = array<${input.type.value}, ${roi.length}>(${roi.map(i => `${i}`).join(',')}); var inputIndices: ${input.type.indices}; for (var i:u32 = 0; i < ${outputShape.length}; i++) { var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; @@ -280,12 +282,13 @@ const calculateInputIndicesFromOutputIndices = if (scales[i] == 1.0) { inputIndex = outputIndex; } else { - var original_idx = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i], - f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]); - if (!${useExtrapolation} || (original_idx >= 0 && original_idx < f32(inputShape[i]))) { + var original_idx = getOriginalCoordinateFromResizedCoordinate(${input.type.value}(outputIndex), scales[i], + ${input.type.value}(outputShape[i]), ${input.type.value}(inputShape[i]), roi[i], roi[i + ${ + inputShape.length}]); + if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${input.type.value}(inputShape[i]))) { if (original_idx < 0) { inputIndex = 0; - } else if (original_idx > (f32(inputShape[i]) - 1)) { + } else if (original_idx > (${input.type.value}(inputShape[i]) - 1)) { inputIndex = inputShape[i] - 1; } else { inputIndex = u32(getNearestPixelFromOriginal(original_idx, scales[i] < 1)); @@ -316,8 +319,9 @@ const bilinearInterpolation = useExtrapolation: boolean, extrapolationValue: number): string => { const [batchIdx, heightIdx, widthIdx, channelIdx] = inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]); + const dType = input.type.value; return ` - fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> f32 { + fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} { var inputIndices: ${input.type.indices}; inputIndices[${heightIdx}] = max(0, min(row, ${inputShape[heightIdx]} - 1)); inputIndices[${widthIdx}] = max(0, min(col, ${inputShape[widthIdx]} - 1)); @@ -328,10 +332,10 @@ const bilinearInterpolation = return input[${input.indicesToOffset('inputIndices')}]; } - fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> f32 { + fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> ${dType} { var originalIndices = calculateOriginalIndicesFromOutputIndices(outputIndices); - var row:f32 = originalIndices[${heightIdx}]; - var col:f32 = originalIndices[${widthIdx}]; + var row:${dType} = originalIndices[${heightIdx}]; + var col:${dType} = originalIndices[${widthIdx}]; if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${ inputShape[widthIdx]} - 1)) { return ${extrapolationValue}; @@ -348,14 +352,14 @@ const bilinearInterpolation = channel = u32(originalIndices[${channelIdx}]); batch = u32(originalIndices[${batchIdx}]); } - var x11: f32 = getInputValue(batch, channel, row1, col1); - var x12: f32 = getInputValue(batch, channel, row1, col2); - var x21: f32 = getInputValue(batch, channel, row2, col1); - var x22: f32 = getInputValue(batch, channel, row2, col2); - var dx1: f32 = row - f32(row1); - var dx2: f32 = f32(row2 ) - row; - var dy1 = col - f32(col1); - var dy2 = f32(col2) - col; + var x11: ${dType} = getInputValue(batch, channel, row1, col1); + var x12: ${dType} = getInputValue(batch, channel, row1, col2); + var x21: ${dType} = getInputValue(batch, channel, row2, col1); + var x22: ${dType} = getInputValue(batch, channel, row2, col2); + var dx1: ${dType} = row - ${dType}(row1); + var dx2: ${dType} = ${dType}(row2) - row; + var dy1 = col - ${dType}(col1); + var dy2 = ${dType}(col2) - col; return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1); }`; }; @@ -365,24 +369,24 @@ const bicubicInterpolation = scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean, extrapolationValue: number, excludeOutside: boolean): string => { const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2]; - + const dType = input.type.value; const createCubicInterpolationFunction = (idx: number): string => { const direction = idx === heightIdx ? 'row' : 'col'; return ` fn ${direction}CubicInterpolation(inputIndices: ${input.type.indices}, outputIndices: ${ - output.type.indices}) -> f32 { + output.type.indices}) -> ${dType} { var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : `outputIndices[${idx}]`}; - var originalIdx: f32 = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), ${scales[idx]}, - f32(${outputShape[idx]}), f32(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); - var fractOriginalIdx: f32 = originalIdx - floor(originalIdx); + var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(outputIndex), ${scales[idx]}, + ${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length}); + var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx); var coefs = getCubicInterpolationCoefs(fractOriginalIdx); if (${useExtrapolation} && (originalIdx < 0 || originalIdx > (${inputShape[idx]} - 1))) { return ${extrapolationValue}; } - var data: array = array(0.0, 0.0, 0.0, 0.0); + var data: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0); for (var i: i32 = -1; i < 3; i++) { - var ${direction}: f32 = originalIdx + f32(i); + var ${direction}: ${dType} = originalIdx + ${dType}(i); if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) { if (${excludeOutside}) { coefs[i + 1] = 0.0; @@ -405,12 +409,12 @@ const bicubicInterpolation = return ` ${createCubicInterpolationFunction(heightIdx)}; ${createCubicInterpolationFunction(widthIdx)}; - fn getCubicInterpolationCoefs(s: f32) -> array { + fn getCubicInterpolationCoefs(s: ${dType}) -> array<${dType}, 4> { var absS = abs(s); - var coeffs: array = array(0.0, 0.0, 0.0, 0.0); - var oneMinusAbsS: f32 = 1.0 - absS; - var twoMinusAbsS: f32 = 2.0 - absS; - var onePlusAbsS: f32 = 1.0 + absS; + var coeffs: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0); + var oneMinusAbsS: ${dType} = 1.0 - absS; + var twoMinusAbsS: ${dType} = 2.0 - absS; + var onePlusAbsS: ${dType} = 1.0 + absS; coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${ cubicCoeffA}) * onePlusAbsS - 4 * ${cubicCoeffA}; coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1; @@ -420,12 +424,12 @@ const bicubicInterpolation = return coeffs; } - fn cubicInterpolation1D(x: array, coefs: array) -> f32 { - var coefsSum: f32 = coefs[0] + coefs[1] + coefs[2] + coefs[3]; + fn cubicInterpolation1D(x: array<${dType}, 4>, coefs: array<${dType}, 4>) -> ${dType} { + var coefsSum: ${dType} = coefs[0] + coefs[1] + coefs[2] + coefs[3]; return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum; } - fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> f32 { + fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> ${dType} { var inputIndices: ${input.type.indices} = outputIndices; return colCubicInterpolation(inputIndices, outputIndices); } @@ -451,15 +455,16 @@ const createResizeProgramInfo = const outputSize = ShapeUtil.size(outputShape); const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]); const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize'; + const dataType = input.type.value; const getShaderSource = (shaderHelper: ShaderHelper) => ` ${noScale ? '' : ` - ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)}; + ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode, dataType)}; ${(() => { switch (attributes.mode) { case 'nearest': return ` ${checkInputIndices(input, inputShape)}; - ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion)}; + ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)}; ${ calculateInputIndicesFromOutputIndices( input, output, inputShape, outputShape, scales, roi, useExtrapolation)};