diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 5fffa2f266603..0eb0d40a3ea5e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -772,14 +772,14 @@ class ShaderHelperImpl implements ShaderHelper { const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1; const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3, @builtin(local_invocation_id) local_id : vec3` : - `@builtin(local_invocation_index) local_index : u32, + `@builtin(local_invocation_index) local_idx : u32, @builtin(workgroup_id) workgroup_id : vec3, @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch ? - 'let global_idx = global_id.x;' : + 'let global_idx = global_id.x; let local_idx = local_id.x;' : `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${ - workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_index;`; + workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_idx;`; return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ}) fn main(${paramList}) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index 6e9dee41ce488..1c5d28e4b8e3f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -97,8 +97,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - let m = global_id.x / N; - let n = global_id.x % N; + let m = global_idx / N; + let n = global_idx % N; var value = ${dataType}(0); for (var k: u32 = 0u; k<${K}u; k++) { @@ -107,7 +107,7 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt ${calculateAlpha} ${calculateC} - output[global_id.x] = value; + output[global_idx] = value; }`; return { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts index 1365d1e9a12a4..7c440cbffea7b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts @@ -141,7 +141,6 @@ export const createReduceSharedProgramInfo = return ((a - 1u) / b + 1u); } ${shaderHelper.mainStart(workgroupSize)} - let local_idx = local_id.x; let outputIndex = global_idx / ${workgroupSize}; let offset = outputIndex * uniforms.reduceSize; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 378a7e738dac9..324dc3af1a710 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -73,8 +73,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut } ${shaderHelper.registerUniform('packedCols', 'i32').declareVariables(x, output)} ${shaderHelper.mainStart()} - let gindex = i32(global_id.x); - let lindex = i32(local_id.x); + let gindex = i32(global_idx); + let lindex = i32(local_idx); const wg = ${WG}; let row = gindex / wg; let cols = uniforms.packedCols; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 51114d8a99dd1..a25e7fe4229b4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -125,8 +125,8 @@ export interface ClipAttributes extends AttributeWithCacheKey { } const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { - const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; - const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; + const min = (inputs.length >= 2 && inputs[1].data !== 0) ? inputs[1].getFloat32Array()[0] : MIN_CLIP; + const max = (inputs.length >= 3 && inputs[2].data !== 0) ? inputs[2].getFloat32Array()[0] : MAX_CLIP; return createAttributeWithCacheKey({min, max}); };