Improves 2d tiled matmulnbits by repeating A, loads N times for each B load
sushraja-msft committed Dec 10, 2024
1 parent aa51ec8 commit 401938f
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
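
The win here: previously each 16-row tile of A got its own workgroup, and every one of those workgroups fetched the same tile of B per K-step; now a single workgroup fetches B once per step and replays it against A_REPEAT tiles of A. A minimal host-side sketch of that accounting (editor's illustration only, not part of the commit; the constants mirror the shader):

// Editor's sketch, not part of the commit: count B-tile loads needed to
// cover one column of output tiles, before and after the change.
#include <cstdio>

int main() {
  const unsigned M = 512, K = 1024, kTileSize = 16, kARepeat = 8;
  // One K-step consumes BLOCKS_PER_CYCLE * QUANTIZATION_BLOCK_SIZE = 2 * 32 = 64 values.
  const unsigned steps = K / 64;
  const unsigned row_tiles = (M + kTileSize - 1) / kTileSize;

  // Before: one workgroup per 16-row tile of A; each loads the B tile every step.
  const unsigned b_loads_before = row_tiles * steps;
  // After: one workgroup per kARepeat row tiles; B is loaded once per step and
  // reused for all kARepeat A tiles.
  const unsigned b_loads_after = (row_tiles / kARepeat) * steps;

  printf("B tile loads: before=%u, after=%u\n", b_loads_before, b_loads_after);
}

At A_REPEAT = 8 that is the A_REPEAT-fold cut in B traffic that the commit title describes.
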
@@ -334,11 +334,13 @@ Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const
 // Note in matmulnbits, B matrix is already transposed, however the following remains true
 // for the shader below M describes A, N describes B and K is the hidden/shared dimension.
 // K4/K8 are simply K divided by 4 or 8 respectively.
+// A_REPEAT is the number of times each workgroup reloads A while sharing B.
 shader.AdditionalImplementation() << R"INIT_SECTION(
 // Matrix dimensions and quantization parameters
 const TILE_SIZE : u32 = 16u;
 const VALUES_PER_VEC4 : u32 = 4u;
 const QUANTIZATION_BLOCK_SIZE : u32 = 32;
+const A_REPEAT : u32 = 8u;
 // We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM,
 // so we use BLOCKS_PER_CYCLE as 2u, or process weights 2 blocks at a time.
 // This uses all 16 lanes on 12th gen intel chips.
@@ -349,13 +351,10 @@ const VECTORIZED_QUANTIZATION_BLOCK_SIZE: u32 = 8u; // QUANTIZATION_BLOCK_SIZE /
 //Shared memory
 var<workgroup> tile_A : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
 var<workgroup> tile_B : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
-var<workgroup> tile_O : array<array<output_value_t, TILE_SIZE>, TILE_SIZE>;
+var<workgroup> tile_O : array<array<output_value_t, TILE_SIZE>, TILE_SIZE * A_REPEAT>;
 fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32)
 {
-  if (a_global >= uniforms.M) {
-    return;
-  }
   let local_A = input_a[a_global*uniforms.K4+step_idx*INNER_DIMENSION_ITEMS_PER_CYCLE+parallel_id];
   tile_A[slot][parallel_id] = local_A;
 }
@@ -417,21 +416,36 @@ fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t
 // a single wave in this approach of indexing.
 let idx = u32(local_idx / TILE_SIZE);
 let idy = u32(local_idx % TILE_SIZE);
-let a_global_base = workgroup_id.x * TILE_SIZE;
+let a_global_base = workgroup_id.x * TILE_SIZE * A_REPEAT;
 let b_global_base = workgroup_id.y * TILE_SIZE;
 let step_count:u32 = u32(uniforms.K/(BLOCKS_PER_CYCLE*QUANTIZATION_BLOCK_SIZE));
 for (var vec_step:u32 = 0; vec_step < step_count; vec_step++)
 {
   workgroupBarrier();
-  loadA(idx, a_global_base+idx, vec_step, idy);
   loadB(idx, b_global_base+idx, vec_step, idy);
   workgroupBarrier();
-  let result = computeDotProduct(idx, idy);
-  tile_O[idx][idy]+=result;
+  for (var repeat_offset:u32=0; repeat_offset<A_REPEAT*TILE_SIZE; repeat_offset+=TILE_SIZE)
+  {
+    let a_global = a_global_base+idx+repeat_offset;
+    if (a_global < uniforms.M)
+    {
+      loadA(idx, a_global_base+idx+repeat_offset, vec_step, idy);
+      let result = computeDotProduct(idx, idy);
+      tile_O[idx+repeat_offset][idy]+=result;
+    }
+  }
 }
 workgroupBarrier();
-if (a_global_base+idx < uniforms.M && b_global_base+idy < uniforms.N) {
-  output[(a_global_base+idx) * uniforms.N + b_global_base + idy] = tile_O[idx][idy];
+if (b_global_base+idy < uniforms.N) {
+  for (var a_repeat:u32=0; a_repeat<A_REPEAT; a_repeat++)
+  {
+    let ridx = a_repeat * TILE_SIZE + idx;
+    let a_global = a_global_base+ridx;
+    if (a_global < uniforms.M)
+    {
+      output[(a_global) * uniforms.N + b_global_base + idy] = tile_O[ridx][idy];
+    }
+  }
 }
 )MAIN_FN";
 return Status::OK();
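
The delicate part of the new shader is the row bookkeeping: a_global_base now strides by TILE_SIZE * A_REPEAT per workgroup, accumulation goes to tile_O[idx+repeat_offset], and the writeback reconstructs the row as a_repeat * TILE_SIZE + idx. The following host-side replay (editor's sketch, not part of the commit) checks that this arithmetic visits every row of A below M exactly once, including a ragged final tile:

#include <cassert>
#include <vector>

int main() {
  const unsigned kTileSize = 16, kARepeat = 8;
  const unsigned M = 500;  // deliberately not a multiple of 128 to exercise the bounds check
  const unsigned groups_x = (M + kTileSize * kARepeat - 1) / (kTileSize * kARepeat);
  std::vector<unsigned> visits(M, 0);

  for (unsigned wg = 0; wg < groups_x; wg++) {
    const unsigned a_global_base = wg * kTileSize * kARepeat;  // matches the shader
    for (unsigned idx = 0; idx < kTileSize; idx++) {           // one thread row per idx
      for (unsigned off = 0; off < kARepeat * kTileSize; off += kTileSize) {
        const unsigned a_global = a_global_base + idx + off;
        if (a_global < M) {                                    // same guard as the shader
          visits[a_global]++;
        }
      }
    }
  }
  for (unsigned v : visits) assert(v == 1);  // each row of A covered exactly once
}
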
Expand Down Expand Up @@ -486,8 +500,13 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
// MatMulNBitsProgramPrefill does not use any of the subgroup wgsl instructions. The subgroup
// size just helps with optimal lane usage in the shader.
constexpr int32_t subgroup_size = 16;
// How many times each workgroup reloads A sharing B. This is tuneable,
// 8 produces a good performance for sequence length of 256/512, 16 will give
// slightly better performance for seqeengths of 1024.
// Note: This should match A_REPEAT in the shader.
constexpr unsigned int kMatMulPrefillARepeat = 8;
program.SetWorkgroupSize(tile_size * subgroup_size);
program.SetDispatchGroupSize((M + tile_size - 1) / tile_size,
program.SetDispatchGroupSize((M + (tile_size * kMatMulPrefillARepeat) - 1) / (tile_size * kMatMulPrefillARepeat),
(N + tile_size - 1) / tile_size,
1);
program
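
The dispatch math above shrinks the x-dimension by the same factor the shader gained in reuse. Worked example (editor's arithmetic, not from the commit):

#include <cstdio>

int main() {
  const unsigned M = 1024, tile_size = 16, kMatMulPrefillARepeat = 8;
  const unsigned x_before = (M + tile_size - 1) / tile_size;
  const unsigned x_after = (M + (tile_size * kMatMulPrefillARepeat) - 1) /
                           (tile_size * kMatMulPrefillARepeat);
  printf("dispatch x for M=%u: before=%u, after=%u\n", M, x_before, x_after);  // 64 vs 8
}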
