Skip to content

Commit

Permalink
More change to fix bilinear/trilinear Interpolation work.
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Dec 22, 2023
1 parent 825bde3 commit 531d65f
Showing 1 changed file with 81 additions and 45 deletions.
126 changes: 81 additions & 45 deletions js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,11 @@ const setChannelAndBatchIndices =
'';

const bilinearInterpolation =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[],
useExtrapolation: boolean, extrapolationValue: number): string => {
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean,
extrapolationValue: number): string => {
const isNchw = true;
const [batchIdx, heightIdx, widthIdx, channelIdx] =
inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]);
inputShape.length === 2 ? [-1, 0, 1, -1] : (isNchw ? [0, 2, 3, 1] : [0, 1, 2, 3]);
const dType = input.type.value;
return `
fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} {
Expand All @@ -339,10 +340,12 @@ const bilinearInterpolation =
var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
var row:${dType} = originalIndices[${heightIdx}];
var col:${dType} = originalIndices[${widthIdx}];
if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${
inputShape[widthIdx]} - 1)) {
${
useExtrapolation ? `if (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > (${
inputShape[widthIdx]} - 1))) {
return ${extrapolationValue};
}
}` :
''};
row = max(0, min(row, ${inputShape[heightIdx]} - 1));
col = max(0, min(col, ${inputShape[widthIdx]} - 1));
var row1: u32 = u32(row);
Expand All @@ -355,10 +358,18 @@ const bilinearInterpolation =
var x12: ${dType} = getInputValue(batch, channel, row1, col2);
var x21: ${dType} = getInputValue(batch, channel, row2, col1);
var x22: ${dType} = getInputValue(batch, channel, row2, col2);
var dx1: ${dType} = row - ${dType}(row1);
var dx2: ${dType} = ${dType}(row2) - row;
var dy1 = col - ${dType}(col1);
var dy2 = ${dType}(col2) - col;
var dx1: ${dType} = abs(row - ${dType}(row1));
var dx2: ${dType} = abs(${dType}(row2) - row);
var dy1: ${dType} = abs(col - ${dType}(col1));
var dy2: ${dType} = abs(${dType}(col2) - col);
if (row1 == row2) {
dx1 = 0.5;
dx2 = 0.5;
}
if (col1 == col2) {
dy1 = 0.5;
dy2 = 0.5;
}
return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
}`;
};
Expand All @@ -367,7 +378,9 @@ const bicubicInterpolation =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[],
scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean,
extrapolationValue: number, excludeOutside: boolean): string => {
const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2];
const is2D = inputShape.length === 2;
const isNchw = true;
const [heightIdx, widthIdx] = is2D ? [0, 1] : isNchw ? [2, 3] : [1, 2];
const dType = input.type.value;
const createCubicInterpolationFunction = (idx: number): string => {
const direction = idx === heightIdx ? 'row' : 'col';
Expand All @@ -387,16 +400,18 @@ const bicubicInterpolation =
for (var i: i32 = -1; i < 3; i++) {
var ${direction}: ${dType} = originalIdx + ${dType}(i);
if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) {
if (${excludeOutside}) {
coefs[i + 1] = 0.0;
continue;
} else if (${useExtrapolation}) {
return ${extrapolationValue};
} else {
${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));
}
${(() => {
if (excludeOutside) {
return `coefs[i + 1] = 0.0;
continue;`;
} else if (useExtrapolation) {
return `return ${extrapolationValue};`;
} else {
return `${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));`;
}
})()};
}
var input_indices_copy: ${input.type.indices} = input_indices;
var input_indices_copy: ${input.type.indices} = input_indices;
${input.indicesSet('input_indices_copy', idx, `u32(${direction})`)};
data[i + 1] = ${
idx === heightIdx ? input.getByIndices('input_indices_copy') :
Expand Down Expand Up @@ -437,10 +452,11 @@ const bicubicInterpolation =
};

const trilinearInterpolation =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[],
useExtrapolation: boolean, extrapolationValue: number): string => {
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean,
extrapolationValue: number): string => {
const isNchw = true;
const [batchIdx, depthIdx, heightIdx, widthIdx, channelIdx] =
inputShape.length === 3 ? [-1, 0, 1, 2, -1] : (scales[1] === 1.0 ? [0, 2, 3, 4, 1] : [0, 1, 2, 3, 4]);
inputShape.length === 3 ? [-1, 0, 1, 2, -1] : (isNchw ? [0, 2, 3, 4, 1] : [0, 1, 2, 3, 4]);
const dType = input.type.value;
return `
fn getInputValue(batch: u32, channel: u32, depth:u32, height: u32, width: u32) -> ${dType} {
Expand All @@ -457,11 +473,14 @@ const trilinearInterpolation =
var depth:${dType} = originalIndices[${depthIdx}];
var height:${dType} = originalIndices[${heightIdx}];
var width:${dType} = originalIndices[${widthIdx}];
if (${useExtrapolation} && (depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${
inputShape[heightIdx]} - 1) || width < 0 || width > ${inputShape[widthIdx]} - 1)) {
return ${extrapolationValue};
}
depth = max(0, min(depth, ${inputShape[depthIdx]} - 1));
${
useExtrapolation ? `(depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${
inputShape[heightIdx]} - 1) || width < 0 || (width > ${inputShape[widthIdx]} - 1))) {
return ${extrapolationValue};
}` :
''};
depth = max(0, min(depth, ${inputShape[depthIdx]} - 1));
height = max(0, min(height, ${inputShape[heightIdx]} - 1));
width = max(0, min(width, ${inputShape[widthIdx]} - 1));
var depth1: u32 = u32(depth);
Expand All @@ -481,12 +500,24 @@ const trilinearInterpolation =
var x212: ${dType} = getInputValue(batch, channel, depth2, height1, width2);
var x221: ${dType} = getInputValue(batch, channel, depth2, height2, width1);
var x222: ${dType} = getInputValue(batch, channel, depth2, height2, width2);
var dx1 = depth - ${dType}(depth1);
var dx2 = ${dType}(depth2) - depth;
var dy1 = height - ${dType}(height1);
var dy2 = ${dType}(height2) - height;
var dz1: ${dType} = width - ${dType}(width1);
var dz2: ${dType} = ${dType}(width2) - width;
var dx1: ${dType} = abs(depth - ${dType}(depth1));
var dx2: ${dType} = abs(${dType}(depth2) - depth);
var dy1: ${dType} = abs(height - ${dType}(height1));
var dy2: ${dType} = abs(${dType}(height2) - height);
var dz1: ${dType} = abs(width - ${dType}(width1));
var dz2: ${dType} = abs(${dType}(width2) - width);
if (depth1 == depth2) {
dx1 = 0.5;
dx2 = 0.5;
}
if (height1 == height2) {
dy1 = 0.5;
dy2 = 0.5;
}
if (width1 == width2) {
dx1 = 0.5;
dx2 = 0.5;
}
return (x111 * dx2 * dy2 * dz2 + x112 * dx2 * dy2 * dz1 + x121 * dx2 * dy1 *dz2 + x122 * dx2 * dy1 * dz1 +
x211 * dx1 * dy2 * dz2 + x212 * dx1 * dy2 * dz1 + x221 * dx1 * dy1 *dz2 + x222 * dx1 * dy1 * dz1);
}`;
Expand Down Expand Up @@ -531,23 +562,28 @@ const createResizeProgramInfo =
${(() => {
if (inputShape.length === 2 || inputShape.length === 4) {
return `${
bilinearInterpolation(
input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)}`;
bilinearInterpolation(input, output, inputShape, useExtrapolation, attributes.extrapolationValue)}`;
} else if (inputShape.length === 3 || inputShape.length === 5) {
return `${
trilinearInterpolation(
input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)}`;
input, output, inputShape, useExtrapolation, attributes.extrapolationValue)}`;
} else {
throw Error('Only input dims 2, 3, 4 and 5 are supported in linear mode.');
throw Error('Linear mode only supports input dims 2, 3, 4 and 5 are supported in linear mode.');
}
})()};
`;
case 'cubic':
return `
${
bicubicInterpolation(
input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation,
attributes.extrapolationValue, attributes.excludeOutside)};
${(() => {
if (inputShape.length === 2 || inputShape.length === 4) {
return `${
bicubicInterpolation(
input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation,
attributes.extrapolationValue, attributes.excludeOutside)}`;
} else {
throw Error('Cubic mode only supports input dims 2 and 4 are supported in linear mode.');
}
})()};
`;
default:
throw Error('Invalid resize mode');
Expand Down Expand Up @@ -575,8 +611,8 @@ const createResizeProgramInfo =
}`;
case 'linear':
return `output[global_idx] = ${
inputShape.length === 2 || inputShape.length === 4 ? 'bilinearInterpolation' :
'trilinearInterpolation'}(output_indices);`;
(inputShape.length === 2 || inputShape.length === 4) ? 'bilinearInterpolation' :
'trilinearInterpolation'}(output_indices);`;
case 'cubic':
return 'output[global_idx] = bicubicInterpolation(output_indices);';
default:
Expand All @@ -590,7 +626,7 @@ const createResizeProgramInfo =
name: 'Resize',
shaderCache: {
hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${
sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}`,
sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}|${inputShape}`,
inputDependencies: ['rank']
},
getShaderSource,
Expand Down

0 comments on commit 531d65f

Please sign in to comment.