Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JS/WebGPU] Add trilinear interpolation to Resize; activation_params attribute is optional for FusedConv also. #18842

Merged
merged 12 commits into from
Dec 28, 2023
190 changes: 148 additions & 42 deletions js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ const initOutputShape =
return outputShape;
};

const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes): number[] => {
const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes) => {
const scaleInPolicy = (() => {
switch (attributes.keepAspectRatioPolicy) {
case 'not_larger':
Expand Down Expand Up @@ -312,52 +312,64 @@ const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]):
return true;
}`;

const setChannelAndBatchIndices =
(input: IndicesHelper, channelIdx: number, batchIdx: number, spacialDims: number): string =>
input.rank > spacialDims ? `
${input.indicesSet('input_indices', channelIdx, 'channel')};
${input.indicesSet('input_indices', batchIdx, 'batch')};
` :
'';

const bilinearInterpolation =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[],
useExtrapolation: boolean, extrapolationValue: number): string => {
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean,
extrapolationValue: number): string => {
const isNchw = true;
const [batchIdx, heightIdx, widthIdx, channelIdx] =
inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]);
inputShape.length === 2 ? [-1, 0, 1, -1] : (isNchw ? [0, 2, 3, 1] : [0, 1, 2, 3]);
const dType = input.type.value;
return `
fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} {
var input_indices: ${input.type.indices};
${input.indicesSet('input_indices', heightIdx, `max(0, min(row, ${inputShape[heightIdx]} - 1))`)};
${input.indicesSet('input_indices', widthIdx, `max(0, min(col, ${inputShape[widthIdx]} - 1))`)};
if (${inputShape.length} > 2) {
${input.indicesSet('input_indices', channelIdx, 'channel')};
${input.indicesSet('input_indices', batchIdx, 'batch')};
};
${setChannelAndBatchIndices(input, channelIdx, batchIdx, 2)}
return ${input.getByIndices('input_indices')};
}

fn bilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
var row:${dType} = originalIndices[${heightIdx}];
var col:${dType} = originalIndices[${widthIdx}];
if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${
inputShape[widthIdx]} - 1)) {
${
useExtrapolation ?
`if (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > (${inputShape[widthIdx]} - 1))) {
return ${extrapolationValue};
}
}` :
''};
row = max(0, min(row, ${inputShape[heightIdx]} - 1));
col = max(0, min(col, ${inputShape[widthIdx]} - 1));
var row1: u32 = u32(row);
var col1: u32 = u32(col);
var row2: u32 = u32(row + 1);
var col2: u32 = u32(col + 1);
var channel: u32 = 0;
var batch: u32 = 0;
if (${inputShape.length > 2}) {
channel = u32(originalIndices[${channelIdx}]);
batch = u32(originalIndices[${batchIdx}]);
}
var channel: u32 = ${inputShape.length > 2 ? `u32(originalIndices[${channelIdx}])` : '0'};
var batch: u32 = ${inputShape.length > 2 ? `u32(originalIndices[${batchIdx}])` : '0'};
var x11: ${dType} = getInputValue(batch, channel, row1, col1);
var x12: ${dType} = getInputValue(batch, channel, row1, col2);
var x21: ${dType} = getInputValue(batch, channel, row2, col1);
var x22: ${dType} = getInputValue(batch, channel, row2, col2);
var dx1: ${dType} = row - ${dType}(row1);
var dx2: ${dType} = ${dType}(row2) - row;
var dy1 = col - ${dType}(col1);
var dy2 = ${dType}(col2) - col;
var dx1: ${dType} = abs(row - ${dType}(row1));
var dx2: ${dType} = abs(${dType}(row2) - row);
var dy1: ${dType} = abs(col - ${dType}(col1));
var dy2: ${dType} = abs(${dType}(col2) - col);
if (row1 == row2) {
dx1 = 0.5;
dx2 = 0.5;
}
if (col1 == col2) {
dy1 = 0.5;
dy2 = 0.5;
}
return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
}`;
};
Expand All @@ -366,7 +378,9 @@ const bicubicInterpolation =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[],
scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean,
extrapolationValue: number, excludeOutside: boolean): string => {
const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2];
const is2D = inputShape.length === 2;
const isNchw = true;
const [heightIdx, widthIdx] = is2D ? [0, 1] : isNchw ? [2, 3] : [1, 2];
const dType = input.type.value;
const createCubicInterpolationFunction = (idx: number): string => {
const direction = idx === heightIdx ? 'row' : 'col';
Expand All @@ -386,16 +400,18 @@ const bicubicInterpolation =
for (var i: i32 = -1; i < 3; i++) {
var ${direction}: ${dType} = originalIdx + ${dType}(i);
if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) {
if (${excludeOutside}) {
coefs[i + 1] = 0.0;
continue;
} else if (${useExtrapolation}) {
return ${extrapolationValue};
} else {
${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));
}
${(() => {
if (excludeOutside) {
return `coefs[i + 1] = 0.0;
continue;`;
} else if (useExtrapolation) {
return `return ${extrapolationValue};`;
} else {
return `${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));`;
}
var input_indices_copy: ${input.type.indices} = input_indices;
})()};
}
var input_indices_copy: ${input.type.indices} = input_indices;
${input.indicesSet('input_indices_copy', idx, `u32(${direction})`)};
data[i + 1] = ${
idx === heightIdx ? input.getByIndices('input_indices_copy') :
Expand Down Expand Up @@ -435,6 +451,78 @@ const bicubicInterpolation =
`;
};

const trilinearInterpolation =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean,
extrapolationValue: number): string => {
const isNchw = true;
const [batchIdx, depthIdx, heightIdx, widthIdx, channelIdx] =
inputShape.length === 3 ? [-1, 0, 1, 2, -1] : (isNchw ? [0, 2, 3, 4, 1] : [0, 1, 2, 3, 4]);
const dType = input.type.value;
return `
fn getInputValue(batch: u32, channel: u32, depth:u32, height: u32, width: u32) -> ${dType} {
var input_indices: ${input.type.indices};
${input.indicesSet('input_indices', depthIdx, `max(0, min(depth, ${inputShape[depthIdx]} - 1))`)};
${input.indicesSet('input_indices', heightIdx, `max(0, min(height, ${inputShape[heightIdx]} - 1))`)};
${input.indicesSet('input_indices', widthIdx, `max(0, min(width, ${inputShape[widthIdx]} - 1))`)};
${setChannelAndBatchIndices(input, channelIdx, batchIdx, 3)}
return ${input.getByIndices('input_indices')};
}

fn trilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
var depth:${dType} = originalIndices[${depthIdx}];
var height:${dType} = originalIndices[${heightIdx}];
var width:${dType} = originalIndices[${widthIdx}];
${
useExtrapolation ? `if (depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${
inputShape[heightIdx]} - 1) || width < 0 || (width > ${inputShape[widthIdx]} - 1))) {
return ${extrapolationValue};
}` :
''};

depth = max(0, min(depth, ${inputShape[depthIdx]} - 1));
height = max(0, min(height, ${inputShape[heightIdx]} - 1));
width = max(0, min(width, ${inputShape[widthIdx]} - 1));
var depth1: u32 = u32(depth);
var height1: u32 = u32(height);
var width1: u32 = u32(width);
var depth2: u32 = u32(depth + 1);
var height2: u32 = u32(height + 1);
var width2: u32 = u32(width + 1);
var channel: u32 = ${inputShape.length > 3 ? `u32(originalIndices[${channelIdx}])` : '0'};
var batch: u32 = ${inputShape.length > 3 ? `u32(originalIndices[${batchIdx}])` : '0'};

var x111: ${dType} = getInputValue(batch, channel, depth1, height1, width1);
var x112: ${dType} = getInputValue(batch, channel, depth1, height1, width2);
var x121: ${dType} = getInputValue(batch, channel, depth1, height2, width1);
var x122: ${dType} = getInputValue(batch, channel, depth1, height2, width2);
var x211: ${dType} = getInputValue(batch, channel, depth2, height1, width1);
var x212: ${dType} = getInputValue(batch, channel, depth2, height1, width2);
var x221: ${dType} = getInputValue(batch, channel, depth2, height2, width1);
var x222: ${dType} = getInputValue(batch, channel, depth2, height2, width2);
var dx1: ${dType} = abs(depth - ${dType}(depth1));
var dx2: ${dType} = abs(${dType}(depth2) - depth);
var dy1: ${dType} = abs(height - ${dType}(height1));
var dy2: ${dType} = abs(${dType}(height2) - height);
var dz1: ${dType} = abs(width - ${dType}(width1));
var dz2: ${dType} = abs(${dType}(width2) - width);
if (depth1 == depth2) {
dx1 = 0.5;
dx2 = 0.5;
}
if (height1 == height2) {
dy1 = 0.5;
dy2 = 0.5;
}
if (width1 == width2) {
dz1 = 0.5;
dz2 = 0.5;
}
return (x111 * dx2 * dy2 * dz2 + x112 * dx2 * dy2 * dz1 + x121 * dx2 * dy1 *dz2 + x122 * dx2 * dy1 * dz1 +
x211 * dx1 * dy2 * dz2 + x212 * dx1 * dy2 * dz1 + x221 * dx1 * dy1 *dz2 + x222 * dx1 * dy1 * dz1);
}`;
};

const createResizeProgramInfo =
(inputTensor: TensorView, attributes: ResizeAttributes, opsetVersion: number, scalesInput: readonly number[],
sizes: readonly number[], roiInput: readonly number[]): ProgramInfo => {
Expand All @@ -454,6 +542,7 @@ const createResizeProgramInfo =
const outputSize = ShapeUtil.size(outputShape);
const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]);
const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize';
const extrapolationValue = attributes.extrapolationValue;
const dataType = input.type.value;
const getShaderSource = (shaderHelper: ShaderHelper) => `
${noScale ? '' : `
Expand All @@ -471,16 +560,28 @@ const createResizeProgramInfo =
case 'linear':
return `
${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales.length, roi.length)};
${
bilinearInterpolation(
input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)};
`;
${(() => {
if (inputShape.length === 2 || inputShape.length === 4) {
return `${bilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`;
} else if (inputShape.length === 3 || inputShape.length === 5) {
return `${trilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`;
} else {
throw Error('Linear mode only supports input dims 2, 3, 4 and 5 are supported in linear mode.');
}
})()};
`;
case 'cubic':
return `
${
bicubicInterpolation(
input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation,
attributes.extrapolationValue, attributes.excludeOutside)};
${(() => {
if (inputShape.length === 2 || inputShape.length === 4) {
return `${
bicubicInterpolation(
input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation,
attributes.extrapolationValue, attributes.excludeOutside)}`;
} else {
throw Error('Cubic mode only supports input dims 2 and 4 are supported in linear mode.');
}
})()};
`;
default:
throw Error('Invalid resize mode');
Expand All @@ -507,21 +608,23 @@ const createResizeProgramInfo =
output[global_idx] = ${attributes.extrapolationValue};
}`;
case 'linear':
return 'output[global_idx] = bilinearInterpolation(output_indices);';
return `output[global_idx] = ${
(inputShape.length === 2 || inputShape.length === 4) ? 'bilinearInterpolation' :
'trilinearInterpolation'}(output_indices);`;
case 'cubic':
return 'output[global_idx] = bicubicInterpolation(output_indices);';
default:
throw Error(`Unsupported resize mode: ${attributes.mode}`);
}
})()};
`}
`}
}`;

return {
name: 'Resize',
shaderCache: {
hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${
sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}`,
sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}|${inputShape}`,
inputDependencies: ['rank']
},
getShaderSource,
Expand Down Expand Up @@ -551,6 +654,9 @@ export const resize = (context: ComputeContext, attributes: ResizeAttributes): v
const sizes: number[] = [];
const roi: number[] = [];
const opsetVersion = getOpsetVersionFromCustomDataBuffer(context);
if (attributes.antialias !== 0) {
throw Error('Only default value (0) for Antialias attribute is supported');
}
validateInputs(context.inputs, attributes, opsetVersion, scales, sizes, roi);
context.compute(
createResizeProgramInfo(context.inputs[0], attributes, opsetVersion, scales, sizes, roi), {inputs: [0]});
Expand Down
Loading
Loading