Mac fix and improve comments
sushraja-msft committed Dec 10, 2024
1 parent ffb2dab commit aa51ec8
Showing 2 changed files with 24 additions and 12 deletions.
30 changes: 23 additions & 7 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -327,13 +327,20 @@ Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const
shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddInput("scales", ShaderUsage::UseUniform);
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
+  // This shader uses uniforms with the M,N,K convention from traditional matrix multiplication.
+  // M is the number of rows in A and the number of rows in the output.
+  // N is the number of columns in B and the number of columns in the output.
+  // K is the hidden/shared dimension: the number of columns in A and of rows in B.
+  // Note that in MatMulNBits the B matrix is already transposed; even so, for the shader below,
+  // M describes A, N describes B, and K is the hidden/shared dimension.
+  // K4/K8 are simply K divided by 4 or 8 respectively.
shader.AdditionalImplementation() << R"INIT_SECTION(
// Matrix dimensions and quantization parameters
const TILE_SIZE : u32 = 16u;
const VALUES_PER_VEC4 : u32 = 4u;
const QUANTIZATION_BLOCK_SIZE : u32 = 32;
-  // We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU,
-  // so we use BLOCKS_PER_CYCLE as 2u, that is process weights 2 blocks at a time.
+  // We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM,
+  // so we use BLOCKS_PER_CYCLE as 2u, i.e. process weights 2 blocks at a time.
+  // This uses all 16 lanes on 12th gen Intel chips.
const BLOCKS_PER_CYCLE : u32 = 2u;
const INNER_DIMENSION_ITEMS_PER_CYCLE : u32 = 16u; // (QUANTIZATION_BLOCK_SIZE/VALUES_PER_VEC4)*BLOCKS_PER_CYCLE
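
As a sanity check on the constants above, the derived quantities reduce to plain integer arithmetic. A minimal host-side C++ sketch, not part of the commit (K here is an example value):

    // Illustrative only: derived quantities used by the prefill shader,
    // assuming K is a multiple of QUANTIZATION_BLOCK_SIZE (32).
    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t K = 4096;    // example hidden/shared dimension
      const uint32_t K4 = K / 4;  // uniforms.K4: K counted in vec4 units
      const uint32_t K8 = K / 8;  // uniforms.K8: K counted in packed u32s (8 nibbles each)
      const uint32_t quantization_block_size = 32;
      const uint32_t values_per_vec4 = 4;
      const uint32_t blocks_per_cycle = 2;
      // (32 / 4) * 2 == 16: one inner-dimension item per lane on a 16-lane EU.
      assert((quantization_block_size / values_per_vec4) * blocks_per_cycle == 16u);
      assert(K4 == 1024u && K8 == 512u);
      return 0;
    }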
@@ -355,7 +362,8 @@ fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32)
fn getBScale(slot: u32, b_global : u32, vec_step_idx : u32, scale_idx: u32) -> output_value_t
{
-    // Since scales are output_value_t holding 1 for 32 values each, vec_step_idx jumps over 64 entries at a time.
+    // Since scales are output_value_t, holding one scale for every 32 weights, vec_step_idx jumps
+    // over 64 weights, or 2 scales, at every step.
let scale_offset = vec_step_idx*2;
let idx = u32(b_global*(uniforms.K/QUANTIZATION_BLOCK_SIZE)+scale_offset);
return scales[idx+scale_idx];
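
The index arithmetic in getBScale can be spelled out as plain integer math: one scale covers a 32-value quantization block, and each step advances 64 weights, i.e. 2 scales. A hypothetical C++ helper mirroring it (scale_index is not in the file):

    #include <cstdint>

    // Illustrative only: flat index of the scale for row b_global of B at the
    // given step, assuming K is a multiple of the quantization block size (32).
    uint32_t scale_index(uint32_t b_global, uint32_t vec_step_idx,
                         uint32_t scale_idx, uint32_t K) {
      const uint32_t scales_per_row = K / 32;          // K / QUANTIZATION_BLOCK_SIZE
      const uint32_t scale_offset = vec_step_idx * 2;  // 64 weights per step => 2 scales
      return b_global * scales_per_row + scale_offset + scale_idx;
    }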
@@ -370,7 +378,10 @@ fn loadB(slot: u32, b_global : u32, vec_step_idx : u32, parallel_id : u32)
let idx:u32 = parallel_id;
if (idx % 2 == 0)
{
-    // Since weights are u32 holding 8 values each, vec_step_idx jumps over 64 each time.
+    // Weights are u32s holding 8 values each, and each step (vec_step_idx) jumps over 64 weights.
+    // Therefore the weight offset at the start of the current step would be vec_step_idx * 64 if
+    // each element held a single weight. Since each element instead holds 8 values, the start
+    // becomes vec_step_idx * 64/8, or vec_step_idx * 8.
var weight_offset:u32 = (vec_step_idx*8)+ u32(idx/2);
let b_value = input_b[b_global*uniforms.K8+weight_offset];
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
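
Each u32 in input_b packs eight 4-bit weights, two per byte. A rough C++ equivalent of the masking that unpack4xU8 performs here (illustrative; the matching upper-nibble unpack lies outside the visible hunk, and lower-nibble-first ordering within each byte is an assumption):

    #include <array>
    #include <cstdint>

    // Illustrative only: split a u32 of 8 packed 4-bit weights into the lower
    // and upper nibbles of its 4 bytes, mirroring unpack4xU8(b_value & 0x0F0F0F0Fu)
    // and a presumed unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu).
    void unpack_nibbles(uint32_t b_value,
                        std::array<uint8_t, 4>& lower,
                        std::array<uint8_t, 4>& upper) {
      for (int byte = 0; byte < 4; ++byte) {
        const uint8_t b = static_cast<uint8_t>(b_value >> (8 * byte));
        lower[byte] = b & 0x0F;  // b_value & 0x0F0F0F0F
        upper[byte] = b >> 4;    // (b_value >> 4) & 0x0F0F0F0F
      }
    }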
@@ -400,8 +411,10 @@ fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t
shader.MainFunctionBody() << R"MAIN_FN(
// Indexing with idx,idy instead of using a 2d dispatch of TILE_SIZE, TILE_SIZE
// appears to give a performance win on Intel Gen12LP architecture.
-  // This could likley because of locality of memory access that changes with
-  // having idy be consecutive lanes in an EU.
+  // This is likely because of locality of memory access: with this indexing, idy
+  // is the same as the subgroup id or lane id, while idx is the wave id.
+  // The work distribution therefore keeps memory accesses within a single
+  // wave close together.
let idx = u32(local_idx / TILE_SIZE);
let idy = u32(local_idx % TILE_SIZE);
let a_global_base = workgroup_id.x * TILE_SIZE;
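
To see the distribution concretely: with TILE_SIZE == 16, consecutive values of local_idx share the same idx (wave) and take consecutive idy (lane) values. A small standalone C++ illustration, not from the file:

    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint32_t TILE_SIZE = 16;
      // local_idx 0..15 map to idx 0 (one wave), 16..31 to idx 1, and so on.
      for (uint32_t local_idx = 0; local_idx < 32; ++local_idx) {
        const uint32_t idx = local_idx / TILE_SIZE;  // wave id
        const uint32_t idy = local_idx % TILE_SIZE;  // lane id within the wave
        std::printf("local_idx=%2u -> idx=%u, idy=%2u\n", local_idx, idx, idy);
      }
      return 0;
    }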
@@ -467,8 +480,11 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
if (use_block32 && batch_count == 1 &&
components_a == 4 && components_b == 4 &&
!has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) {
-    MatMulNBitsProgramPrefill program{false};
+    MatMulNBitsProgramPrefill program;
constexpr int32_t tile_size = 16;
+    // subgroup_size here controls how many elements of the hidden dimension we load in a cycle.
+    // MatMulNBitsProgramPrefill does not use any of the subgroup WGSL instructions; the subgroup
+    // size just helps with optimal lane usage in the shader.
constexpr int32_t subgroup_size = 16;
program.SetWorkgroupSize(tile_size * subgroup_size);
program.SetDispatchGroupSize((M + tile_size - 1) / tile_size,
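The dispatch math is a ceil-division over 16-row tiles of M: tile_size * subgroup_size gives a 256-invocation workgroup, and (M + tile_size - 1) / tile_size rounds the group count up. A minimal sketch of that arithmetic (illustrative, covering only the visible arguments):

    #include <cstdint>

    // Illustrative only: ceil-division used for the dispatch group count.
    constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    static_assert(16 * 16 == 256, "tile_size * subgroup_size invocations per workgroup");
    static_assert(ceil_div(100, 16) == 7, "e.g. M == 100 needs 7 groups in x");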
6 changes: 1 addition & 5 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
@@ -33,8 +33,7 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {

class MatMulNBitsProgramPrefill final : public Program<MatMulNBitsProgramPrefill> {
public:
-  MatMulNBitsProgramPrefill(bool has_zero_points) : Program{"MatMulNBitsPrefill"},
-                                                    has_zero_points_{has_zero_points} {
+  MatMulNBitsProgramPrefill() : Program{"MatMulNBitsPrefill"} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -44,9 +43,6 @@ class MatMulNBitsProgramPrefill final : public Program<MatMulNBitsProgramPrefill
{"K", ProgramUniformVariableDataType::Uint32},
{"K4", ProgramUniformVariableDataType::Uint32},
{"K8", ProgramUniformVariableDataType::Uint32});
-
- private:
-  bool has_zero_points_;
};

class MatMulNBits final : public WebGpuKernel {
