From 401938fdef76c70d5c2a0537e64db8d7688d68a8 Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Tue, 10 Dec 2024 15:04:43 -0800
Subject: [PATCH] Improves 2d tiled matmulnbits by repeating A, loads N times
 for each B load

---
 .../webgpu/quantization/matmul_nbits.cc       | 39 ++++++++++++++-----
 1 file changed, 29 insertions(+), 10 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
index be18f820e2747..919054796d42c 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -334,11 +334,13 @@ Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const
   // Note in matmulnbits, B matrix is already transposed, however the following remains true
   // for the shader below M describes A, N describes B and K is the hidden/shared dimension.
   // K4/K8 are simply K divided by 4 or 8 respectively.
+  // A_REPEAT, number of times each workgroup reloads A sharing B.
   shader.AdditionalImplementation() << R"INIT_SECTION(
 // Matrix dimensions and quantization parameters
 const TILE_SIZE : u32 = 16u;
 const VALUES_PER_VEC4 : u32 = 4u;
 const QUANTIZATION_BLOCK_SIZE : u32 = 32;
+const A_REPEAT : u32 = 8u;
 // We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM,
 // so we use BLOCKS_PER_CYCLE as 2u, or process weights 2 blocks at a time.
 // This uses all 16 lanes on 12th gen intel chips.
@@ -349,13 +351,10 @@ const VECTORIZED_QUANTIZATION_BLOCK_SIZE: u32 = 8u; // QUANTIZATION_BLOCK_SIZE /
 //Shared memory
 var<workgroup> tile_A : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
 var<workgroup> tile_B : array<array<input_b_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
-var<workgroup> tile_O : array<array<output_value_t, TILE_SIZE>, TILE_SIZE>;
+var<workgroup> tile_O : array<array<output_value_t, TILE_SIZE>, TILE_SIZE * A_REPEAT>;
 
 fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32)
 {
-  if (a_global >= uniforms.M) {
-    return;
-  }
   let local_A = input_a[a_global*uniforms.K4+step_idx*INNER_DIMENSION_ITEMS_PER_CYCLE+parallel_id];
   tile_A[slot][parallel_id] = local_A;
 }
@@ -417,21 +416,36 @@ fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t
   // a single wave in this approach of indexing.
   let idx = u32(local_idx / TILE_SIZE);
   let idy = u32(local_idx % TILE_SIZE);
-  let a_global_base = workgroup_id.x * TILE_SIZE;
+  let a_global_base = workgroup_id.x * TILE_SIZE * A_REPEAT;
   let b_global_base = workgroup_id.y * TILE_SIZE;
   let step_count:u32 = u32(uniforms.K/(BLOCKS_PER_CYCLE*QUANTIZATION_BLOCK_SIZE));
   for (var vec_step:u32 = 0; vec_step < step_count; vec_step++)
   {
     workgroupBarrier();
-    loadA(idx, a_global_base+idx, vec_step, idy);
     loadB(idx, b_global_base+idx, vec_step, idy);
     workgroupBarrier();
-    let result = computeDotProduct(idx, idy);
-    tile_O[idx][idy]+=result;
+    for (var repeat_offset:u32=0; repeat_offset
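
Note (reviewer sketch, not part of the patch): the gain here is tile_B reuse. Each workgroup now covers A_REPEAT row-tiles of A against a single column-tile of B, so every tile_B fetched from global memory is consumed A_REPEAT times instead of once, while total tile_A traffic stays the same. A minimal host-side C++ sketch of that load accounting, assuming hypothetical shapes M = N = K = 1024 and the constants from this patch; it is a counting model only, not the shader itself:

#include <cstdint>
#include <iostream>

int main() {
  const uint32_t TILE_SIZE = 16u, A_REPEAT = 8u;
  const uint32_t BLOCKS_PER_CYCLE = 2u, QUANTIZATION_BLOCK_SIZE = 32u;
  const uint32_t M = 1024u, N = 1024u, K = 1024u;  // hypothetical shapes

  // One inner K step per (BLOCKS_PER_CYCLE * QUANTIZATION_BLOCK_SIZE) elements,
  // matching the shader's vec_step loop.
  const uint32_t step_count = K / (BLOCKS_PER_CYCLE * QUANTIZATION_BLOCK_SIZE);

  // After this patch: workgroup x covers A in chunks of TILE_SIZE * A_REPEAT.
  uint64_t b_loads_after = 0, a_loads_after = 0;
  for (uint32_t wx = 0; wx < M / (TILE_SIZE * A_REPEAT); wx++) {
    for (uint32_t wy = 0; wy < N / TILE_SIZE; wy++) {
      for (uint32_t step = 0; step < step_count; step++) {
        b_loads_after += 1;         // loadB once per vec_step
        a_loads_after += A_REPEAT;  // loadA repeated A_REPEAT times per vec_step
      }
    }
  }

  // Before this patch: one A row-tile per workgroup, one loadA and one loadB
  // per vec_step, over an (M/TILE_SIZE) x (N/TILE_SIZE) workgroup grid.
  const uint64_t b_loads_before =
      uint64_t(M / TILE_SIZE) * (N / TILE_SIZE) * step_count;
  const uint64_t a_loads_before = b_loads_before;

  std::cout << "tile_B loads: " << b_loads_before << " -> " << b_loads_after
            << " (divided by " << A_REPEAT << ")\n";
  std::cout << "tile_A loads: " << a_loads_before << " -> " << a_loads_after
            << " (unchanged)\n";
  return 0;
}

The cost of this reuse is the larger accumulator: tile_O grows from TILE_SIZE to TILE_SIZE * A_REPEAT rows of workgroup memory, which is what the tile_O resize in the patch pays for.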