Skip to content

Commit

Permalink
[JS/WebGPU] fix an error in Clip (microsoft#18799)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Check whether the optional min/max inputs of Clip are actually provided, and fall back to the default values when they are not.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
satyajandhyala authored Dec 19, 2023
1 parent 32fcf73 commit 98510fb
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 11 deletions.
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -772,14 +772,14 @@ class ShaderHelperImpl implements ShaderHelper {
const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1;
const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3<u32>,
@builtin(local_invocation_id) local_id : vec3<u32>` :
`@builtin(local_invocation_index) local_index : u32,
`@builtin(local_invocation_index) local_idx : u32,
@builtin(workgroup_id) workgroup_id : vec3<u32>,
@builtin(num_workgroups) num_workgroups : vec3<u32>`;
const globalIdxDefinition = is1DimensionDispatch ?
'let global_idx = global_id.x;' :
'let global_idx = global_id.x; let local_idx = local_id.x;' :
`let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${
workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_index;`;
workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_idx;`;

return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ})
fn main(${paramList}) {
Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/gemm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let m = global_id.x / N;
let n = global_id.x % N;
let m = global_idx / N;
let n = global_idx % N;
var value = ${dataType}(0);
for (var k: u32 = 0u; k<${K}u; k++) {
Expand All @@ -107,7 +107,7 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
${calculateAlpha}
${calculateC}
output[global_id.x] = value;
output[global_idx] = value;
}`;
return {
Expand Down
1 change: 0 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ export const createReduceSharedProgramInfo =
return ((a - 1u) / b + 1u);
}
${shaderHelper.mainStart(workgroupSize)}
let local_idx = local_id.x;
let outputIndex = global_idx / ${workgroupSize};
let offset = outputIndex * uniforms.reduceSize;
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
}
${shaderHelper.registerUniform('packedCols', 'i32').declareVariables(x, output)}
${shaderHelper.mainStart()}
let gindex = i32(global_id.x);
let lindex = i32(local_id.x);
let gindex = i32(global_idx);
let lindex = i32(local_idx);
const wg = ${WG};
let row = gindex / wg;
let cols = uniforms.packedCols;
Expand Down
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ export interface ClipAttributes extends AttributeWithCacheKey {
}

/**
 * Builds the Clip attributes (`min`, `max`) from the operator's input tensors.
 *
 * `inputs[1]` and `inputs[2]` are read as the min and max bounds respectively. Each bound falls
 * back to its default (`MIN_CLIP` / `MAX_CLIP`) when the corresponding input is missing from the
 * list or when its `data` field is 0.
 *
 * @param inputs - the Clip inputs; index 0 is not read here.
 * @returns the `{min, max}` pair wrapped by `createAttributeWithCacheKey`.
 */
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
  // NOTE(review): `data === 0` presumably marks an optional input that was omitted — confirm
  // against the TensorView definition. Only the first element of each bound tensor is used.
  const min = (inputs.length >= 2 && inputs[1].data !== 0) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
  const max = (inputs.length >= 3 && inputs[2].data !== 0) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
  return createAttributeWithCacheKey({min, max});
};

Expand Down

0 comments on commit 98510fb

Please sign in to comment.