From aa51ec8dc2f689c68a10a2b25bef93d3fa24a786 Mon Sep 17 00:00:00 2001
From: Sushanth Rajasankar
Date: Tue, 10 Dec 2024 14:13:41 -0800
Subject: [PATCH] Mac fix and improve comments

---
 .../webgpu/quantization/matmul_nbits.cc | 30 ++++++++++++++------
 .../webgpu/quantization/matmul_nbits.h  |  6 +---
 2 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
index 581959ac706e5..be18f820e2747 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -327,13 +327,20 @@ Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const
   shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
   shader.AddInput("scales", ShaderUsage::UseUniform);
   shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
+  // This shader uses uniforms with the M,N,K convention from traditional matrix multiplication.
+  // M is the number of rows in A and in the output.
+  // N is the number of columns in B and in the output.
+  // K is the hidden/shared dimension: the number of columns in A and rows in B.
+  // Note that in MatMulNBits the B matrix is already transposed; nevertheless, for the shader
+  // below M still describes A, N describes B, and K is the hidden/shared dimension.
+  // K4/K8 are simply K divided by 4 or 8 respectively.
   shader.AdditionalImplementation() << R"INIT_SECTION(
 // Matrix dimensions and quantization parameters
 const TILE_SIZE : u32 = 16u;
 const VALUES_PER_VEC4 : u32 = 4u;
 const QUANTIZATION_BLOCK_SIZE : u32 = 32;
-// We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU,
-// so we use BLOCKS_PER_CYCLE as 2u, that is process weights 2 blocks at a time.
+// We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM,
+// so we use BLOCKS_PER_CYCLE as 2u, i.e. process weights 2 blocks at a time.
 // This uses all 16 lanes on 12th gen intel chips.
 const BLOCKS_PER_CYCLE : u32 = 2u;
 const INNER_DIMENSION_ITEMS_PER_CYCLE : u32 = 16u; // (QUANTIZATION_BLOCK_SIZE/VALUES_PER_VEC4)*BLOCKS_PER_CYCLE
@@ -355,7 +362,8 @@ fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32)
 fn getBScale(slot: u32, b_global : u32, vec_step_idx : u32, scale_idx: u32) -> output_value_t
 {
-    // Since scales are output_value_t holding 1 for 32 values each, vec_step_idx jumps over 64 entries at a time.
+    // Since scales are output_value_t holding one scale for every 32 values, vec_step_idx jumps
+    // over 64 weights, i.e. 2 scales, at every step.
     let scale_offset = vec_step_idx*2;
     let idx = u32(b_global*(uniforms.K/QUANTIZATION_BLOCK_SIZE)+scale_offset);
     return scales[idx+scale_idx];
 }
@@ -370,7 +378,10 @@ fn loadB(slot: u32, b_global : u32, vec_step_idx : u32, parallel_id : u32)
     let idx:u32 = parallel_id;
     if (idx % 2 == 0)
     {
-        // Since weights are u32 holding 8 values each, vec_step_idx jumps over 64 each time.
+        // Weights are u32s holding 8 values each; each step (vec_step_idx) jumps over 64 weights.
+        // The starting weight_offset for the current step would therefore be vec_step_idx * 64 if
+        // each element held a single weight. Since each element actually holds 8 values, the start
+        // becomes vec_step_idx * 64/8, i.e. vec_step_idx * 8.
         var weight_offset:u32 = (vec_step_idx*8)+ u32(idx/2);
         let b_value = input_b[b_global*uniforms.K8+weight_offset];
         let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
@@ -400,8 +411,10 @@ fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t
   shader.MainFunctionBody() << R"MAIN_FN(
   // Indexing with idx,idy instead of using a 2d dispatch of TILE_SIZE, TILE_SIZE
   // appears to give a performance win on Intel Gen12LP architecture.
-  // This could likley because of locality of memory access that changes with
-  // having idy be consecutive lanes in an EU.
+  // This is likely because of locality of memory access: with this indexing, idy below
+  // is effectively the subgroup lane id, while idx is the wave id. The work
+  // distribution therefore keeps the memory accesses of a single wave close
+  // together.
   let idx = u32(local_idx / TILE_SIZE);
   let idy = u32(local_idx % TILE_SIZE);
   let a_global_base = workgroup_id.x * TILE_SIZE;
@@ -467,8 +480,11 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
   if (use_block32 && batch_count == 1 &&
       components_a == 4 && components_b == 4 &&
       !has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) {
-    MatMulNBitsProgramPrefill program{false};
+    MatMulNBitsProgramPrefill program;
     constexpr int32_t tile_size = 16;
+    // subgroup_size here controls how many elements of the hidden dimension we load in a cycle.
+    // MatMulNBitsProgramPrefill does not use any of the subgroup WGSL instructions; the subgroup
+    // size simply helps with optimal lane usage in the shader.
     constexpr int32_t subgroup_size = 16;
     program.SetWorkgroupSize(tile_size * subgroup_size);
     program.SetDispatchGroupSize((M + tile_size - 1) / tile_size,
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
index cf5724cc2c2cb..5f785c03f6a5e 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
@@ -33,8 +33,7 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
 
 class MatMulNBitsProgramPrefill final : public Program<MatMulNBitsProgramPrefill> {
  public:
-  MatMulNBitsProgramPrefill(bool has_zero_points) : Program{"MatMulNBitsPrefill"},
-                                                    has_zero_points_{has_zero_points} {
+  MatMulNBitsProgramPrefill() : Program{"MatMulNBitsPrefill"} {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -44,9 +43,6 @@ class MatMulNBitsProgramPrefill final : public Program
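
Note (not part of the patch): the loadB comment above describes eight 4-bit weights packed into each u32. A minimal host-side C++ sketch of the nibble split that unpack4xU8(b_value & 0x0F0F0F0Fu) performs; the matching upper-nibble unpack falls outside the hunk's context window, so the (b_value >> 4) form below is an assumption:

    #include <array>
    #include <cstdint>
    #include <cstdio>

    // Emulates WGSL's unpack4xU8: split a u32 into its four bytes, low byte first.
    static std::array<uint32_t, 4> Unpack4xU8(uint32_t v) {
      return {v & 0xFFu, (v >> 8) & 0xFFu, (v >> 16) & 0xFFu, (v >> 24) & 0xFFu};
    }

    int main() {
      // Eight packed 4-bit weights 0..7, low nibble first: byte 0x10 holds weights 0 and 1.
      uint32_t b_value = 0x76543210u;
      auto lower = Unpack4xU8(b_value & 0x0F0F0F0Fu);         // even weights: 0, 2, 4, 6
      auto upper = Unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);  // odd weights:  1, 3, 5, 7
      for (int i = 0; i < 4; ++i) {
        std::printf("byte %d: lower=%u upper=%u\n", i, lower[i], upper[i]);
      }
      return 0;
    }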
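The getBScale arithmetic can be checked the same way. A sketch under the assumptions stated in the new comments (one scale per 32 weights, 64 weights consumed per step); ScaleIndex is a hypothetical helper mirroring the shader math, not a function in the codebase:

    #include <cassert>
    #include <cstdint>

    constexpr uint32_t kQuantizationBlockSize = 32;  // weights covered by one scale

    // Flat index into the scales buffer for column b_global of B, step vec_step_idx
    // (64 weights per step), and scale_idx in {0, 1}. Mirrors the shader's
    // scale_offset = vec_step_idx*2; idx = b_global*(K/32) + scale_offset + scale_idx.
    uint32_t ScaleIndex(uint32_t b_global, uint32_t K, uint32_t vec_step_idx, uint32_t scale_idx) {
      uint32_t scale_offset = vec_step_idx * 2;  // 64 weights/step / 32 weights/scale = 2 scales
      return b_global * (K / kQuantizationBlockSize) + scale_offset + scale_idx;
    }

    int main() {
      // With K = 4096 there are 128 scales per column; step 3 starts at scale 6.
      assert(ScaleIndex(0, 4096, 3, 0) == 6);
      assert(ScaleIndex(1, 4096, 0, 1) == 129);
      return 0;
    }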
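Finally, the dispatch sizing in ComputeInternal: the workgroup runs tile_size * subgroup_size = 256 invocations, and the grid tiles M in 16-row chunks; the second SetDispatchGroupSize argument is truncated in the hunk, so tiling N the same way is an assumption here. A sketch with hypothetical shapes:

    #include <cstdint>
    #include <cstdio>

    int main() {
      constexpr int32_t tile_size = 16;
      constexpr int32_t subgroup_size = 16;  // lanes covering the hidden dimension per cycle
      const int32_t M = 100;   // hypothetical prefill sequence length
      const int32_t N = 4096;  // hypothetical output columns

      int32_t workgroup_size = tile_size * subgroup_size;  // 256 invocations per workgroup
      int32_t groups_x = (M + tile_size - 1) / tile_size;  // ceil(100/16) = 7 row tiles
      int32_t groups_y = (N + tile_size - 1) / tile_size;  // ceil(4096/16) = 256 col tiles (assumed)
      std::printf("workgroup=%d dispatch=(%d, %d)\n", workgroup_size, groups_x, groups_y);
      return 0;
    }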