Skip to content

Commit

Permalink
Added Trilinear Interpolation.
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Dec 16, 2023
1 parent c95f1d1 commit a66278d
Showing 1 changed file with 77 additions and 5 deletions.
82 changes: 77 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,68 @@ const bicubicInterpolation =
`;
};

const trilinearInterpolation =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[],
useExtrapolation: boolean, extrapolationValue: number): string => {
const [batchIdx, depthIdx, heightIdx, widthIdx, channelIdx] =
inputShape.length === 3 ? [-1, 0, 1, 2, -1] : (scales[1] === 1.0 ? [0, 2, 3, 4, 1] : [0, 1, 2, 3, 4]);
const dType = input.type.value;
return `
fn getInputValue(batch: u32, channel: u32, depth:u32, height: u32, width: u32) -> ${dType} {
var input_indices: ${input.type.indices};
${input.indicesSet('input_indices', depthIdx, `max(0, min(depth, ${inputShape[depthIdx]} - 1))`)};
${input.indicesSet('input_indices', heightIdx, `max(0, min(height, ${inputShape[heightIdx]} - 1))`)};
${input.indicesSet('input_indices', widthIdx, `max(0, min(width, ${inputShape[widthIdx]} - 1))`)};
if (${inputShape.length} > 3) {
${input.indicesSet('input_indices', channelIdx, 'channel')};
${input.indicesSet('input_indices', batchIdx, 'batch')};
};
return ${input.getByIndices('input_indices')};
}
fn trilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
var depth:${dType} = originalIndices[${depthIdx}];
var height:${dType} = originalIndices[${heightIdx}];
var width:${dType} = originalIndices[${widthIdx}];
if (${useExtrapolation} && (depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${
inputShape[heightIdx]} - 1) || width < 0 || width > ${inputShape[widthIdx]} - 1)) {
return ${extrapolationValue};
}
depth = max(0, min(row, ${inputShape[depthIdx]} - 1));
height = max(0, min(row, ${inputShape[heightIdx]} - 1));
width = max(0, min(col, ${inputShape[widthIdx]} - 1));
var depth1: u32 = u32(depth);
var height1: u32 = u32(height);
var width1: u32 = u32(width);
var depth2: u32 = u32(depth + 1);
var height2: u32 = u32(height + 1);
var width2: u32 = u32(width + 1);
var channel: u32 = 0;
var batch: u32 = 0;
if (${inputShape.length > 2}) {
channel = u32(originalIndices[${channelIdx}]);
batch = u32(originalIndices[${batchIdx}]);
}
var x111: ${dType} = getInputValue(batch, channel, depth1, height1, width1);
var x112: ${dType} = getInputValue(batch, channel, depth1, height1, width2);
var x121: ${dType} = getInputValue(batch, channel, depth1, height2, width1);
var x122: ${dType} = getInputValue(batch, channel, depth1, height2, width2);
var x211: ${dType} = getInputValue(batch, channel, depth2, height1, width1);
var x212: ${dType} = getInputValue(batch, channel, depth2, height1, width2);
var x221: ${dType} = getInputValue(batch, channel, depth2, height2, width1);
var x222: ${dType} = getInputValue(batch, channel, depth2, height2, width2);
var dx1 = depth - ${dType}(depth1);
var dx2 = ${dType}(depth2) - depth;
var dy1 = height - ${dType}(height1);
var dy2 = ${dType}(height2) - height;
var dz1: ${dType} = width - ${dType}(width1);
var dz2: ${dType} = ${dType}(width2) - width;
return (x111 * dx2 * dy2 * dz2 + x112 * dx2 * dy2 * dz1 + x121 * dx2 * dy1 *dz2 + x122 * dx2 * dy1 * dz1 +
x211 * dx1 * dy2 * dz2 + x212 * dx1 * dy2 * dz1 + x221 * dx1 * dy1 *dz2 + x222 * dx1 * dy1 * dz1);
}`;
};

const createResizeProgramInfo =
(inputTensor: TensorView, attributes: ResizeAttributes, opsetVersion: number, scalesInput: readonly number[],
sizes: readonly number[], roiInput: readonly number[]): ProgramInfo => {
Expand Down Expand Up @@ -471,10 +533,20 @@ const createResizeProgramInfo =
case 'linear':
return `
${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales.length, roi.length)};
${
bilinearInterpolation(
input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)};
`;
${(() => {
if (inputShape.length === 2 || inputShape.length === 4) {
return `${
bilinearInterpolation(
input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)}`;
} else if (inputShape.length === 3 || inputShape.length === 5) {
return `${
trilinearInterpolation(
input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)}`;
} else {
throw Error('Only input dims 2, 3, 4 and 5 are supported in linear mode.');
}
})()};
`;
case 'cubic':
return `
${
Expand Down Expand Up @@ -514,7 +586,7 @@ const createResizeProgramInfo =
throw Error(`Unsupported resize mode: ${attributes.mode}`);
}
})()};
`}
`}
}`;

return {
Expand Down

0 comments on commit a66278d

Please sign in to comment.