Skip to content

Commit

Permalink
[JS/Web] Sajandhy/webgpu resize scales rank check (#18954)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
satyajandhyala authored Dec 29, 2023
1 parent 96d1f32 commit 780fc36
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
11 changes: 7 additions & 4 deletions js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ const validateScales = (scales: number[], attributes: ResizeAttributes): void =>
// Check scales dims based on mode: LINEAR, CUBIC
if (scales.length > 0) {
if (attributes.mode === 'linear') {
if (!(scales.length === 2 || (scales.length === 4 && scales[0] === 1 && scales[1] === 1) ||
(scales.length === 4 && scales[0] === 1 && scales[3] === 1))) {
throw new Error('Resize requires scales input size to be 2 or 4 for linear mode');
if (!(scales.length === 2 || scales.length === 3 || (scales.length === 4 && scales[0] === 1 && scales[1] === 1) ||
(scales.length === 4 && scales[0] === 1 && scales[3] === 1) ||
(scales.length === 5 && scales[0] === 1 && scales[1] === 1))) {
throw new Error(
`For linear mode, Resize requires scales to be 2D, 3D, 4D with either two outermost or one innermost and
one outermost scale values equal to 1, or 5D with two outermost scale values equal to 1`);
}
} else if (attributes.mode === 'cubic') {
if (!(scales.length === 2 || (scales.length === 4 && scales[0] === 1 && scales[1] === 1) ||
Expand Down Expand Up @@ -475,7 +478,7 @@ const trilinearInterpolation =
var width:${dType} = originalIndices[${widthIdx}];
${
useExtrapolation ? `if (depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${
inputShape[heightIdx]} - 1) || width < 0 || (width > ${inputShape[widthIdx]} - 1))) {
inputShape[heightIdx]} - 1) || width < 0 || (width > ${inputShape[widthIdx]} - 1)) {
return ${extrapolationValue};
}` :
''};
Expand Down
47 changes: 47 additions & 0 deletions js/web/test/data/ops/resize.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
[
{
"name": "Resize - 5D Trilinear",
"operator": "Resize",
// "opset": { "domain": "", "version": 7 },
"attributes": [
{ "name": "mode", "data": "linear", "type": "string" },
{ "name": "coordinate_transformation_mode", "data": "tf_crop_and_resize", "type": "string" },
{ "name": "extrapolation_value", "data": 10, "type": "float" }
],
"cases": [
{
"name": "X",
"inputs": [
{
"data": [1.0, 3.0, 3.0, 5.0, 3.0, 5.0, 7.0, 9.0],
"dims": [1, 2, 1, 2, 2],
"type": "float32"
},
{
"data": [0, 0, 0, 0, 0, 1, 2, 1, 2, 2],
"dims": [10],
"type": "float32"
},
{
"data": [1, 1, 1, 2, 4],
"dims": [5],
"type": "float32"
}
],
"outputs": [
{
"data": [
1, 1.571428656578064, 2.142857313156128, 2.7142856121063232, 10, 10, 10, 10, 2.3333332538604736,
2.9047622680664062, 3.4761905670166016, 4.047618865966797, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 3, 3.5714287757873535, 4.142857074737549, 4.714285373687744, 10, 10, 10,
10, 5.6666669845581055, 6.238095760345459, 6.809524059295654, 7.380952835083008, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10
],
"dims": [1, 2, 1, 4, 8],
"type": "float32"
}
]
}
]
}
]

0 comments on commit 780fc36

Please sign in to comment.