Skip to content

Commit

Permalink
[js/webgpu] Fix Expand/Gather when input type is bool (#18999)
Browse files Browse the repository at this point in the history
### Description
Also update the op test suite.

### Motivation and Context
Previously the *total* size in case `Expand - last dim is not divisible
by 4` was a multiple of 4, even though the *last dimension* was not, so
the bug has never been caught.
  • Loading branch information
hujiajie authored Jan 5, 2024
1 parent 7f0aac0 commit 447a3a7
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 7 deletions.
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/expand.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
const outputShape: number[] = calculateOutputShape(inputShape, shape);
const dataType = inputs[0].dataType;
const components = dataType === DataType.bool ? 4 : 1;
const outputSize = ShapeUtil.size(outputShape) / components;
const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components);

const enableInputShapeUniform = enableShapesUniforms(inputShape.length);
const enableOutputShapeUniform = enableShapesUniforms(outputShape.length);
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/gather.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath

const axisDimLimit = inputShape[axis];
const components = inputs[0].dataType === DataType.bool ? 4 : 1;
const outputSize = ShapeUtil.size(outputShape) / components;
const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components);

const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length);
const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims;
Expand Down
29 changes: 24 additions & 5 deletions js/web/test/data/ops/expand.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -168,20 +168,39 @@
"name": "Expand - last dim is not divisible by 4",
"inputs": [
{
"data": [true, false, false, true, true, true, false, false, false, true, true, true],
"dims": [2, 6],
"data": [true, false, false, true, true, true],
"dims": [1, 6],
"type": "bool"
},
{
"data": [2, 1],
"data": [3, 1],
"dims": [2],
"type": "int64"
}
],
"outputs": [
{
"data": [true, false, false, true, true, true, false, false, false, true, true, true],
"dims": [2, 6],
"data": [
true,
false,
false,
true,
true,
true,
true,
false,
false,
true,
true,
true,
true,
false,
false,
true,
true,
true
],
"dims": [3, 6],
"type": "bool"
}
]
Expand Down
22 changes: 22 additions & 0 deletions js/web/test/data/ops/gather.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,28 @@
"operator": "Gather",
"attributes": [],
"cases": [
{
"name": "data[4] indices[]",
"inputs": [
{
"data": [false, true, false, false],
"dims": [4],
"type": "bool"
},
{
"data": [1],
"dims": [],
"type": "int32"
}
],
"outputs": [
{
"data": [true],
"dims": [],
"type": "bool"
}
]
},
{
"name": "data[2,4] indices[1]",
"inputs": [
Expand Down

0 comments on commit 447a3a7

Please sign in to comment.