From 99b0e19f1155abd09ac628091639c4d50c154595 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Mon, 29 Apr 2024 14:27:21 -0700 Subject: [PATCH] [JS/WebGPU] MatMulNBits remove unnecessary condition (#20396) Distribute writing-to-output work over all threads in MatMulNBits. ### Description ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index 7f1a5b96863f7..8aabaeb22f4d4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -208,15 +208,17 @@ export const createMatMulNBitsProgramInfo = ${processOneBlock} } workgroupBarrier(); - if (local_id.x == 0u) { - var output_indices: ${output.type.indices}; - ${output.indicesSet('output_indices', '0', 'batch')}; - ${output.indicesSet('output_indices', outputRank - 1, 'col')}; - ${output.indicesSet('output_indices', outputRank - 2, '0')}; - var output_offset = ${output.indicesToOffset('output_indices')}; - for (var m: u32 = 0u; m < ${dimAOuter}u; m++) { + var output_indices: ${output.type.indices}; + var elements_per_thread: u32 = ${Math.ceil(dimAOuter / nBlocksPerCol)}; + ${output.indicesSet('output_indices', '0', 'batch')}; + ${output.indicesSet('output_indices', outputRank - 1, 'col')}; + ${output.indicesSet('output_indices', outputRank - 2, 'local_id.x * elements_per_thread')}; + var output_offset = ${output.indicesToOffset('output_indices')}; + for (var m: u32 = 0u; m < elements_per_thread; m++) { + var row = m + local_id.x * elements_per_thread; + if (row < ${dimAOuter}) { var output_value: ${output.type.value} = ${output.type.value}(0); - var workgroup_shared_offset: u32 = m; + var workgroup_shared_offset: u32 = row; for (var b: u32 = 0u; b < ${nBlocksPerCol}u; b++) { output_value += workgroup_shared[workgroup_shared_offset]; workgroup_shared_offset += ${dimAOuter};