Skip to content

Commit

Permalink
Miscellanious fixes to bilinear/triliner interpolation functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Dec 20, 2023
1 parent 35bbc48 commit 825bde3
Showing 1 changed file with 21 additions and 24 deletions.
45 changes: 21 additions & 24 deletions js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,14 @@ const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]):
return true;
}`;

const setChannelAndBatchIndices =
(input: IndicesHelper, channelIdx: number, batchIdx: number, spacialDims: number): string =>
input.rank > spacialDims ? `
${input.indicesSet('input_indices', channelIdx, 'channel')};
${input.indicesSet('input_indices', batchIdx, 'batch')};
` :
'';

const bilinearInterpolation =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[],
useExtrapolation: boolean, extrapolationValue: number): string => {
Expand All @@ -323,10 +331,7 @@ const bilinearInterpolation =
var input_indices: ${input.type.indices};
${input.indicesSet('input_indices', heightIdx, `max(0, min(row, ${inputShape[heightIdx]} - 1))`)};
${input.indicesSet('input_indices', widthIdx, `max(0, min(col, ${inputShape[widthIdx]} - 1))`)};
if (${inputShape.length} > 2) {
${input.indicesSet('input_indices', channelIdx, 'channel')};
${input.indicesSet('input_indices', batchIdx, 'batch')};
};
${setChannelAndBatchIndices(input, channelIdx, batchIdx, 2)}
return ${input.getByIndices('input_indices')};
}
Expand All @@ -344,12 +349,8 @@ const bilinearInterpolation =
var col1: u32 = u32(col);
var row2: u32 = u32(row + 1);
var col2: u32 = u32(col + 1);
var channel: u32 = 0;
var batch: u32 = 0;
if (${inputShape.length > 2}) {
channel = u32(originalIndices[${channelIdx}]);
batch = u32(originalIndices[${batchIdx}]);
}
var channel: u32 = ${inputShape.length > 2 ? `u32(originalIndices[${channelIdx}])` : '0'};
var batch: u32 = ${inputShape.length > 2 ? `u32(originalIndices[${batchIdx}])` : '0'};
var x11: ${dType} = getInputValue(batch, channel, row1, col1);
var x12: ${dType} = getInputValue(batch, channel, row1, col2);
var x21: ${dType} = getInputValue(batch, channel, row2, col1);
Expand Down Expand Up @@ -447,10 +448,7 @@ const trilinearInterpolation =
${input.indicesSet('input_indices', depthIdx, `max(0, min(depth, ${inputShape[depthIdx]} - 1))`)};
${input.indicesSet('input_indices', heightIdx, `max(0, min(height, ${inputShape[heightIdx]} - 1))`)};
${input.indicesSet('input_indices', widthIdx, `max(0, min(width, ${inputShape[widthIdx]} - 1))`)};
if (${inputShape.length} > 3) {
${input.indicesSet('input_indices', channelIdx, 'channel')};
${input.indicesSet('input_indices', batchIdx, 'batch')};
};
${setChannelAndBatchIndices(input, channelIdx, batchIdx, 3)}
return ${input.getByIndices('input_indices')};
}
Expand All @@ -463,21 +461,18 @@ const trilinearInterpolation =
inputShape[heightIdx]} - 1) || width < 0 || width > ${inputShape[widthIdx]} - 1)) {
return ${extrapolationValue};
}
depth = max(0, min(row, ${inputShape[depthIdx]} - 1));
height = max(0, min(row, ${inputShape[heightIdx]} - 1));
width = max(0, min(col, ${inputShape[widthIdx]} - 1));
depth = max(0, min(depth, ${inputShape[depthIdx]} - 1));
height = max(0, min(height, ${inputShape[heightIdx]} - 1));
width = max(0, min(width, ${inputShape[widthIdx]} - 1));
var depth1: u32 = u32(depth);
var height1: u32 = u32(height);
var width1: u32 = u32(width);
var depth2: u32 = u32(depth + 1);
var height2: u32 = u32(height + 1);
var width2: u32 = u32(width + 1);
var channel: u32 = 0;
var batch: u32 = 0;
if (${inputShape.length > 2}) {
channel = u32(originalIndices[${channelIdx}]);
batch = u32(originalIndices[${batchIdx}]);
}
var channel: u32 = ${inputShape.length > 3 ? `u32(originalIndices[${channelIdx}])` : '0'};
var batch: u32 = ${inputShape.length > 3 ? `u32(originalIndices[${batchIdx}])` : '0'};
var x111: ${dType} = getInputValue(batch, channel, depth1, height1, width1);
var x112: ${dType} = getInputValue(batch, channel, depth1, height1, width2);
var x121: ${dType} = getInputValue(batch, channel, depth1, height2, width1);
Expand Down Expand Up @@ -579,7 +574,9 @@ const createResizeProgramInfo =
output[global_idx] = ${attributes.extrapolationValue};
}`;
case 'linear':
return 'output[global_idx] = bilinearInterpolation(output_indices);';
return `output[global_idx] = ${
inputShape.length === 2 || inputShape.length === 4 ? 'bilinearInterpolation' :
'trilinearInterpolation'}(output_indices);`;
case 'cubic':
return 'output[global_idx] = bicubicInterpolation(output_indices);';
default:
Expand Down

0 comments on commit 825bde3

Please sign in to comment.