diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index 7f1a5b96863f7..8aabaeb22f4d4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -208,15 +208,17 @@ export const createMatMulNBitsProgramInfo = ${processOneBlock} } workgroupBarrier(); - if (local_id.x == 0u) { - var output_indices: ${output.type.indices}; - ${output.indicesSet('output_indices', '0', 'batch')}; - ${output.indicesSet('output_indices', outputRank - 1, 'col')}; - ${output.indicesSet('output_indices', outputRank - 2, '0')}; - var output_offset = ${output.indicesToOffset('output_indices')}; - for (var m: u32 = 0u; m < ${dimAOuter}u; m++) { + var output_indices: ${output.type.indices}; + var elements_per_thread: u32 = ${Math.ceil(dimAOuter / nBlocksPerCol)}; + ${output.indicesSet('output_indices', '0', 'batch')}; + ${output.indicesSet('output_indices', outputRank - 1, 'col')}; + ${output.indicesSet('output_indices', outputRank - 2, 'local_id.x * elements_per_thread')}; + var output_offset = ${output.indicesToOffset('output_indices')}; + for (var m: u32 = 0u; m < elements_per_thread; m++) { + var row = m + local_id.x * elements_per_thread; + if (row < ${dimAOuter}) { var output_value: ${output.type.value} = ${output.type.value}(0); - var workgroup_shared_offset: u32 = m; + var workgroup_shared_offset: u32 = row; for (var b: u32 = 0u; b < ${nBlocksPerCol}u; b++) { output_value += workgroup_shared[workgroup_shared_offset]; workgroup_shared_offset += ${dimAOuter};