Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JS/Web] Resize & BiasSplitGelu fp16 support #18536

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
import {erfImpl} from './unary-op';

const validateInputs = (inputs: readonly TensorView[]): void => {
Expand Down Expand Up @@ -35,14 +35,15 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI
const output = outputVariable('output', inputs[0].dataType, outputShape, 4);

const outputSize = ShapeUtil.size(outputShape) / 4;
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);

const getShaderSource = (shaderHelper: ShaderHelper) => `
const M_SQRT2 = sqrt(2.0);
const halfChannels = ${inputs[0].dims[2] / 4 / 2}u;
${shaderHelper.declareVariables(input, bias, output)}
${erfImpl('vec4f')}
${erfImpl(`vec4<${dataType}>`, dataType)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
Expand Down
151 changes: 78 additions & 73 deletions js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,50 +105,51 @@ const validateInputs =
}
};

const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode): string =>
'fn getOriginalCoordinateFromResizedCoordinate(xResized: f32, xScale: f32, lengthResized: f32,\
lengthOriginal: f32, roiStart: f32, roiEnd: f32) -> f32 { ' +
const getOriginalCoordinateFromResizedCoordinate =
(coordinateTransferMode: CoordinateTransformMode, dType: string): string =>
`fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType},
lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` +
(() => {
switch (coordinateTransferMode) {
case 'asymmetric':
return 'return xResized / xScale;';
case 'pytorch_half_pixel':
return 'if (lengthResized > 1) { \
switch (coordinateTransferMode) {
case 'asymmetric':
return 'return xResized / xScale;';
case 'pytorch_half_pixel':
return 'if (lengthResized > 1) { \
return (xResized + 0.5) / xScale - 0.5; \
} else { \
return 0.0; \
}';
case 'tf_half_pixel_for_nn':
return 'return (xResized + 0.5) / xScale;';
case 'align_corners':
return 'if (lengthResized == 1) { \
case 'tf_half_pixel_for_nn':
return 'return (xResized + 0.5) / xScale;';
case 'align_corners':
return 'if (lengthResized == 1) { \
return 0.0; \
} else { \
return xResized * (lengthOriginal - 1) / (lengthResized - 1); \
}';
case 'tf_crop_and_resize':
return 'if (lengthResized > 1) { \
case 'tf_crop_and_resize':
return `if (lengthResized > 1) { \
return roiStart * (lengthOriginal - 1) + \
(xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \
} else { \
return 0.5 * (roiStart + roiEnd) * f32(lengthOriginal - 1); \
}';
case 'half_pixel_symmetric':
return [
'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;',
'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);',
'return offset + ((xResized + 0.5) / xScale) - 0.5;'
].join('\n');
case 'half_pixel':
return 'return ((xResized + 0.5) / xScale) - 0.5;';
default:
throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`);
}
})() +
return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \
}`;
case 'half_pixel_symmetric':
return [
'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;',
'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);',
'return offset + ((xResized + 0.5) / xScale) - 0.5;'
].join('\n');
case 'half_pixel':
return 'return ((xResized + 0.5) / xScale) - 0.5;';
default:
throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`);
}
})() +
'}';

const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number): string =>
'fn getNearestPixelFromOriginal(xOriginal: f32, isDownSample: bool) -> f32 {' + (() => {
const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number, dType: string): string =>
`fn getNearestPixelFromOriginal(xOriginal: ${dType}, isDownSample: bool) -> ${dType} {` + (() => {
switch (nearestMode) {
case 'round_prefer_ceil':
return 'if (fract(xOriginal) == 0.5) { \
Expand Down Expand Up @@ -246,20 +247,21 @@ const adjustOutputShape = (inputShape: readonly number[], scales: number[], attr
const calculateOriginalIndicesFromOutputIndices =
(output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[],
roi: readonly number[]): string => `
fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array<f32, ${
outputShape.length}> {
fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array<${
output.type.value}, ${outputShape.length}> {
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
const outputShape = array<u32, ${outputShape.length}>(${outputShape.map(i => `${i}u`).join(',')});
const scales = array<f32, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
const roi = array<f32, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
var originalIndices: array<f32, ${outputShape.length}>;
const scales = array<${output.type.value}, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
const roi = array<${output.type.value}, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
var originalIndices: array<${output.type.value}, ${outputShape.length}>;
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
if (scales[i] == 1.0) {
originalIndices[i] = f32(outputIndex);
originalIndices[i] = ${output.type.value}(outputIndex);
} else {
originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i],
f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]);
originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(${output.type.value}(outputIndex), scales[i],
${output.type.value}(outputShape[i]), ${output.type.value}(inputShape[i]), roi[i], roi[i + ${
inputShape.length}]);
}
}
return originalIndices;
Expand All @@ -271,21 +273,22 @@ const calculateInputIndicesFromOutputIndices =
fn calculateInputIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
const outputShape = array<u32, ${outputShape.length}>(${outputShape.map(i => `${i}u`).join(',')});
const scales = array<f32, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
const roi = array<f32, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
const scales = array<${input.type.value}, ${scales.length}>(${scales.map(i => `${i}`).join(',')});
const roi = array<${input.type.value}, ${roi.length}>(${roi.map(i => `${i}`).join(',')});
var inputIndices: ${input.type.indices};
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
var inputIndex: u32;
if (scales[i] == 1.0) {
inputIndex = outputIndex;
} else {
var original_idx = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i],
f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]);
if (!${useExtrapolation} || (original_idx >= 0 && original_idx < f32(inputShape[i]))) {
var original_idx = getOriginalCoordinateFromResizedCoordinate(${input.type.value}(outputIndex), scales[i],
${input.type.value}(outputShape[i]), ${input.type.value}(inputShape[i]), roi[i], roi[i + ${
inputShape.length}]);
if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${input.type.value}(inputShape[i]))) {
if (original_idx < 0) {
inputIndex = 0;
} else if (original_idx > (f32(inputShape[i]) - 1)) {
} else if (original_idx > (${input.type.value}(inputShape[i]) - 1)) {
inputIndex = inputShape[i] - 1;
} else {
inputIndex = u32(getNearestPixelFromOriginal(original_idx, scales[i] < 1));
Expand Down Expand Up @@ -316,8 +319,9 @@ const bilinearInterpolation =
useExtrapolation: boolean, extrapolationValue: number): string => {
const [batchIdx, heightIdx, widthIdx, channelIdx] =
inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]);
const dType = input.type.value;
return `
fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> f32 {
fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} {
var inputIndices: ${input.type.indices};
inputIndices[${heightIdx}] = max(0, min(row, ${inputShape[heightIdx]} - 1));
inputIndices[${widthIdx}] = max(0, min(col, ${inputShape[widthIdx]} - 1));
Expand All @@ -328,10 +332,10 @@ const bilinearInterpolation =
return input[${input.indicesToOffset('inputIndices')}];
}

fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> f32 {
fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> ${dType} {
var originalIndices = calculateOriginalIndicesFromOutputIndices(outputIndices);
var row:f32 = originalIndices[${heightIdx}];
var col:f32 = originalIndices[${widthIdx}];
var row:${dType} = originalIndices[${heightIdx}];
var col:${dType} = originalIndices[${widthIdx}];
if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${
inputShape[widthIdx]} - 1)) {
return ${extrapolationValue};
Expand All @@ -348,14 +352,14 @@ const bilinearInterpolation =
channel = u32(originalIndices[${channelIdx}]);
batch = u32(originalIndices[${batchIdx}]);
}
var x11: f32 = getInputValue(batch, channel, row1, col1);
var x12: f32 = getInputValue(batch, channel, row1, col2);
var x21: f32 = getInputValue(batch, channel, row2, col1);
var x22: f32 = getInputValue(batch, channel, row2, col2);
var dx1: f32 = row - f32(row1);
var dx2: f32 = f32(row2 ) - row;
var dy1 = col - f32(col1);
var dy2 = f32(col2) - col;
var x11: ${dType} = getInputValue(batch, channel, row1, col1);
var x12: ${dType} = getInputValue(batch, channel, row1, col2);
var x21: ${dType} = getInputValue(batch, channel, row2, col1);
var x22: ${dType} = getInputValue(batch, channel, row2, col2);
var dx1: ${dType} = row - ${dType}(row1);
var dx2: ${dType} = ${dType}(row2) - row;
var dy1 = col - ${dType}(col1);
var dy2 = ${dType}(col2) - col;
return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
}`;
};
Expand All @@ -365,24 +369,24 @@ const bicubicInterpolation =
scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean,
extrapolationValue: number, excludeOutside: boolean): string => {
const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2];

const dType = input.type.value;
const createCubicInterpolationFunction = (idx: number): string => {
const direction = idx === heightIdx ? 'row' : 'col';
return `
fn ${direction}CubicInterpolation(inputIndices: ${input.type.indices}, outputIndices: ${
output.type.indices}) -> f32 {
output.type.indices}) -> ${dType} {
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : `outputIndices[${idx}]`};
var originalIdx: f32 = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), ${scales[idx]},
f32(${outputShape[idx]}), f32(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
var fractOriginalIdx: f32 = originalIdx - floor(originalIdx);
var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(outputIndex), ${scales[idx]},
${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx);
var coefs = getCubicInterpolationCoefs(fractOriginalIdx);

if (${useExtrapolation} && (originalIdx < 0 || originalIdx > (${inputShape[idx]} - 1))) {
return ${extrapolationValue};
}
var data: array<f32, 4> = array<f32, 4>(0.0, 0.0, 0.0, 0.0);
var data: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0);
for (var i: i32 = -1; i < 3; i++) {
var ${direction}: f32 = originalIdx + f32(i);
var ${direction}: ${dType} = originalIdx + ${dType}(i);
if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) {
if (${excludeOutside}) {
coefs[i + 1] = 0.0;
Expand All @@ -405,12 +409,12 @@ const bicubicInterpolation =
return `
${createCubicInterpolationFunction(heightIdx)};
${createCubicInterpolationFunction(widthIdx)};
fn getCubicInterpolationCoefs(s: f32) -> array<f32, 4> {
fn getCubicInterpolationCoefs(s: ${dType}) -> array<${dType}, 4> {
var absS = abs(s);
var coeffs: array<f32, 4> = array<f32, 4>(0.0, 0.0, 0.0, 0.0);
var oneMinusAbsS: f32 = 1.0 - absS;
var twoMinusAbsS: f32 = 2.0 - absS;
var onePlusAbsS: f32 = 1.0 + absS;
var coeffs: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0);
var oneMinusAbsS: ${dType} = 1.0 - absS;
var twoMinusAbsS: ${dType} = 2.0 - absS;
var onePlusAbsS: ${dType} = 1.0 + absS;
coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${
cubicCoeffA}) * onePlusAbsS - 4 * ${cubicCoeffA};
coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1;
Expand All @@ -420,12 +424,12 @@ const bicubicInterpolation =
return coeffs;
}

fn cubicInterpolation1D(x: array<f32, 4>, coefs: array<f32, 4>) -> f32 {
var coefsSum: f32 = coefs[0] + coefs[1] + coefs[2] + coefs[3];
fn cubicInterpolation1D(x: array<${dType}, 4>, coefs: array<${dType}, 4>) -> ${dType} {
var coefsSum: ${dType} = coefs[0] + coefs[1] + coefs[2] + coefs[3];
return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum;
}

fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> f32 {
fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> ${dType} {
var inputIndices: ${input.type.indices} = outputIndices;
return colCubicInterpolation(inputIndices, outputIndices);
}
Expand All @@ -451,15 +455,16 @@ const createResizeProgramInfo =
const outputSize = ShapeUtil.size(outputShape);
const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]);
const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize';
const dataType = input.type.value;
const getShaderSource = (shaderHelper: ShaderHelper) => `
${noScale ? '' : `
${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)};
${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode, dataType)};
${(() => {
switch (attributes.mode) {
case 'nearest':
return `
${checkInputIndices(input, inputShape)};
${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion)};
${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)};
${
calculateInputIndicesFromOutputIndices(
input, output, inputShape, outputShape, scales, roi, useExtrapolation)};
Expand Down
Loading