Skip to content

Commit

Permalink
Resize & BiasSplitGelu fp16 support
Browse files Browse the repository at this point in the history
  • Loading branch information
dakenf committed Nov 21, 2023
1 parent 97cc40d commit 70b9b72
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 75 deletions.
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
import {erfImpl} from './unary-op';

const validateInputs = (inputs: readonly TensorView[]): void => {
Expand Down Expand Up @@ -35,14 +35,15 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI
const output = outputVariable('output', inputs[0].dataType, outputShape, 4);

const outputSize = ShapeUtil.size(outputShape) / 4;
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);

const getShaderSource = (shaderHelper: ShaderHelper) => `
const M_SQRT2 = sqrt(2.0);
const halfChannels = ${inputs[0].dims[2] / 4 / 2}u;
${shaderHelper.declareVariables(input, bias, output)}
${erfImpl('vec4f')}
${erfImpl(`vec4<${dataType}>`, dataType)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
Expand Down
151 changes: 78 additions & 73 deletions js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,50 +105,51 @@ const validateInputs =
}
};

const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode): string =>
'fn getOriginalCoordinateFromResizedCoordinate(xResized: f32, xScale: f32, lengthResized: f32,\
lengthOriginal: f32, roiStart: f32, roiEnd: f32) -> f32 { ' +
const getOriginalCoordinateFromResizedCoordinate =
(coordinateTransferMode: CoordinateTransformMode, dType: string): string =>
`fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType},
lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` +
(() => {
switch (coordinateTransferMode) {
case 'asymmetric':
return 'return xResized / xScale;';
case 'pytorch_half_pixel':
return 'if (lengthResized > 1) { \
switch (coordinateTransferMode) {
case 'asymmetric':
return 'return xResized / xScale;';
case 'pytorch_half_pixel':
return 'if (lengthResized > 1) { \
return (xResized + 0.5) / xScale - 0.5; \
} else { \
return 0.0; \
}';
case 'tf_half_pixel_for_nn':
return 'return (xResized + 0.5) / xScale;';
case 'align_corners':
return 'if (lengthResized == 1) { \
case 'tf_half_pixel_for_nn':
return 'return (xResized + 0.5) / xScale;';
case 'align_corners':
return 'if (lengthResized == 1) { \
return 0.0; \
} else { \
return xResized * (lengthOriginal - 1) / (lengthResized - 1); \
}';
case 'tf_crop_and_resize':
return 'if (lengthResized > 1) { \
case 'tf_crop_and_resize':
return `if (lengthResized > 1) { \
return roiStart * (lengthOriginal - 1) + \
(xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \
} else { \
return 0.5 * (roiStart + roiEnd) * f32(lengthOriginal - 1); \
}';
case 'half_pixel_symmetric':
return [
'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;',
'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);',
'return offset + ((xResized + 0.5) / xScale) - 0.5;'
].join('\n');
case 'half_pixel':
return 'return ((xResized + 0.5) / xScale) - 0.5;';
default:
throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`);
}
})() +
return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \
}`;
case 'half_pixel_symmetric':
return [
'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;',
'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);',
'return offset + ((xResized + 0.5) / xScale) - 0.5;'
].join('\n');
case 'half_pixel':
return 'return ((xResized + 0.5) / xScale) - 0.5;';
default:
throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`);
}
})() +
'}';

const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number): string =>
'fn getNearestPixelFromOriginal(xOriginal: f32, isDownSample: bool) -> f32 {' + (() => {
const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number, dType: string): string =>
`fn getNearestPixelFromOriginal(xOriginal: ${dType}, isDownSample: bool) -> ${dType} {` + (() => {
switch (nearestMode) {
case 'round_prefer_ceil':
return 'if (fract(xOriginal) == 0.5) { \
Expand Down Expand Up @@ -246,20 +247,21 @@ const adjustOutputShape = (inputShape: readonly number[], scales: number[], attr
const calculateOriginalIndicesFromOutputIndices =
(output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[],
roi: readonly number[]): string => `
fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array<f32, ${
outputShape.length}> {
fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array<${
output.type.value}, ${outputShape.length}> {
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
const outputShape = array<u32, ${outputShape.length}>(${outputShape.map(i => `${i}u`).join(',')});
const scales = array<f32, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
const roi = array<f32, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
var originalIndices: array<f32, ${outputShape.length}>;
const scales = array<${output.type.value}, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
const roi = array<${output.type.value}, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
var originalIndices: array<${output.type.value}, ${outputShape.length}>;
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
if (scales[i] == 1.0) {
originalIndices[i] = f32(outputIndex);
originalIndices[i] = ${output.type.value}(outputIndex);
} else {
originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i],
f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]);
originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(${output.type.value}(outputIndex), scales[i],
${output.type.value}(outputShape[i]), ${output.type.value}(inputShape[i]), roi[i], roi[i + ${
inputShape.length}]);
}
}
return originalIndices;
Expand All @@ -271,21 +273,22 @@ const calculateInputIndicesFromOutputIndices =
fn calculateInputIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
const outputShape = array<u32, ${outputShape.length}>(${outputShape.map(i => `${i}u`).join(',')});
const scales = array<f32, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
const roi = array<f32, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
const scales = array<${input.type.value}, ${scales.length}>(${scales.map(i => `${i}`).join(',')});
const roi = array<${input.type.value}, ${roi.length}>(${roi.map(i => `${i}`).join(',')});
var inputIndices: ${input.type.indices};
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
var inputIndex: u32;
if (scales[i] == 1.0) {
inputIndex = outputIndex;
} else {
var original_idx = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i],
f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]);
if (!${useExtrapolation} || (original_idx >= 0 && original_idx < f32(inputShape[i]))) {
var original_idx = getOriginalCoordinateFromResizedCoordinate(${input.type.value}(outputIndex), scales[i],
${input.type.value}(outputShape[i]), ${input.type.value}(inputShape[i]), roi[i], roi[i + ${
inputShape.length}]);
if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${input.type.value}(inputShape[i]))) {
if (original_idx < 0) {
inputIndex = 0;
} else if (original_idx > (f32(inputShape[i]) - 1)) {
} else if (original_idx > (${input.type.value}(inputShape[i]) - 1)) {
inputIndex = inputShape[i] - 1;
} else {
inputIndex = u32(getNearestPixelFromOriginal(original_idx, scales[i] < 1));
Expand Down Expand Up @@ -316,8 +319,9 @@ const bilinearInterpolation =
useExtrapolation: boolean, extrapolationValue: number): string => {
const [batchIdx, heightIdx, widthIdx, channelIdx] =
inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]);
const dType = input.type.value;
return `
fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> f32 {
fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} {
var inputIndices: ${input.type.indices};
inputIndices[${heightIdx}] = max(0, min(row, ${inputShape[heightIdx]} - 1));
inputIndices[${widthIdx}] = max(0, min(col, ${inputShape[widthIdx]} - 1));
Expand All @@ -328,10 +332,10 @@ const bilinearInterpolation =
return input[${input.indicesToOffset('inputIndices')}];
}
fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> f32 {
fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> ${dType} {
var originalIndices = calculateOriginalIndicesFromOutputIndices(outputIndices);
var row:f32 = originalIndices[${heightIdx}];
var col:f32 = originalIndices[${widthIdx}];
var row:${dType} = originalIndices[${heightIdx}];
var col:${dType} = originalIndices[${widthIdx}];
if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${
inputShape[widthIdx]} - 1)) {
return ${extrapolationValue};
Expand All @@ -348,14 +352,14 @@ const bilinearInterpolation =
channel = u32(originalIndices[${channelIdx}]);
batch = u32(originalIndices[${batchIdx}]);
}
var x11: f32 = getInputValue(batch, channel, row1, col1);
var x12: f32 = getInputValue(batch, channel, row1, col2);
var x21: f32 = getInputValue(batch, channel, row2, col1);
var x22: f32 = getInputValue(batch, channel, row2, col2);
var dx1: f32 = row - f32(row1);
var dx2: f32 = f32(row2 ) - row;
var dy1 = col - f32(col1);
var dy2 = f32(col2) - col;
var x11: ${dType} = getInputValue(batch, channel, row1, col1);
var x12: ${dType} = getInputValue(batch, channel, row1, col2);
var x21: ${dType} = getInputValue(batch, channel, row2, col1);
var x22: ${dType} = getInputValue(batch, channel, row2, col2);
var dx1: ${dType} = row - ${dType}(row1);
var dx2: ${dType} = ${dType}(row2) - row;
var dy1 = col - ${dType}(col1);
var dy2 = ${dType}(col2) - col;
return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
}`;
};
Expand All @@ -365,24 +369,24 @@ const bicubicInterpolation =
scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean,
extrapolationValue: number, excludeOutside: boolean): string => {
const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2];

const dType = input.type.value;
const createCubicInterpolationFunction = (idx: number): string => {
const direction = idx === heightIdx ? 'row' : 'col';
return `
fn ${direction}CubicInterpolation(inputIndices: ${input.type.indices}, outputIndices: ${
output.type.indices}) -> f32 {
output.type.indices}) -> ${dType} {
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : `outputIndices[${idx}]`};
var originalIdx: f32 = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), ${scales[idx]},
f32(${outputShape[idx]}), f32(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
var fractOriginalIdx: f32 = originalIdx - floor(originalIdx);
var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(outputIndex), ${scales[idx]},
${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx);
var coefs = getCubicInterpolationCoefs(fractOriginalIdx);
if (${useExtrapolation} && (originalIdx < 0 || originalIdx > (${inputShape[idx]} - 1))) {
return ${extrapolationValue};
}
var data: array<f32, 4> = array<f32, 4>(0.0, 0.0, 0.0, 0.0);
var data: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0);
for (var i: i32 = -1; i < 3; i++) {
var ${direction}: f32 = originalIdx + f32(i);
var ${direction}: ${dType} = originalIdx + ${dType}(i);
if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) {
if (${excludeOutside}) {
coefs[i + 1] = 0.0;
Expand All @@ -405,12 +409,12 @@ const bicubicInterpolation =
return `
${createCubicInterpolationFunction(heightIdx)};
${createCubicInterpolationFunction(widthIdx)};
fn getCubicInterpolationCoefs(s: f32) -> array<f32, 4> {
fn getCubicInterpolationCoefs(s: ${dType}) -> array<${dType}, 4> {
var absS = abs(s);
var coeffs: array<f32, 4> = array<f32, 4>(0.0, 0.0, 0.0, 0.0);
var oneMinusAbsS: f32 = 1.0 - absS;
var twoMinusAbsS: f32 = 2.0 - absS;
var onePlusAbsS: f32 = 1.0 + absS;
var coeffs: array<${dType}, 4> = array<${dType}, 4>(0.0, 0.0, 0.0, 0.0);
var oneMinusAbsS: ${dType} = 1.0 - absS;
var twoMinusAbsS: ${dType} = 2.0 - absS;
var onePlusAbsS: ${dType} = 1.0 + absS;
coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${
cubicCoeffA}) * onePlusAbsS - 4 * ${cubicCoeffA};
coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1;
Expand All @@ -420,12 +424,12 @@ const bicubicInterpolation =
return coeffs;
}
fn cubicInterpolation1D(x: array<f32, 4>, coefs: array<f32, 4>) -> f32 {
var coefsSum: f32 = coefs[0] + coefs[1] + coefs[2] + coefs[3];
fn cubicInterpolation1D(x: array<${dType}, 4>, coefs: array<${dType}, 4>) -> ${dType} {
var coefsSum: ${dType} = coefs[0] + coefs[1] + coefs[2] + coefs[3];
return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum;
}
fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> f32 {
fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> ${dType} {
var inputIndices: ${input.type.indices} = outputIndices;
return colCubicInterpolation(inputIndices, outputIndices);
}
Expand All @@ -451,15 +455,16 @@ const createResizeProgramInfo =
const outputSize = ShapeUtil.size(outputShape);
const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]);
const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize';
const dataType = input.type.value;
const getShaderSource = (shaderHelper: ShaderHelper) => `
${noScale ? '' : `
${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)};
${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode, dataType)};
${(() => {
switch (attributes.mode) {
case 'nearest':
return `
${checkInputIndices(input, inputShape)};
${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion)};
${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)};
${
calculateInputIndicesFromOutputIndices(
input, output, inputShape, outputShape, scales, roi, useExtrapolation)};
Expand Down

0 comments on commit 70b9b72

Please sign in to comment.