Skip to content

Commit

Permalink
Implement 2d tiled matmulnbits specialized for prefill (#23058)
Browse files Browse the repository at this point in the history
### Description
This change implements matmul4bits with tiling both for A and B. This is
beneficial for prefill scenarios on Intel integrated GPUs, because each
row of A has to run through the same set of shared rows of B. This
change should improve core occupancy and model_benchmark does indicate
improvements for prefill.

The same shader is not used for generation because when A has just a
single row, the other threads in the workgroup get unused and that hurts
performance.

```
-- Baseline run on an Alderlake GPU --

C:\onnxruntime>C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500
Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       1.72338e+07
        avg (tokens/s): 29.0707                          << 
        p50 (us):       1.72548e+07
        stddev (us):    57012.8
        n:              5 * 501 token(s)
Token generation:
        avg (us):       79227.5
        avg (tokens/s): 12.6219
        p50 (us):       79284.4
        stddev (us):    2109.72
        n:              635 * 1 token(s)
Token sampling:
        avg (us):       15.8198
        avg (tokens/s): 63211.8
        p50 (us):       14.3
        stddev (us):    8.67178
        n:              640 * 1 token(s)
E2E generation (entire generation loop):
        avg (ms):       27297.8
        p50 (ms):       27269.8
        stddev (ms):    89.4322
        n:              5
Peak working set size (bytes): 5490987008
WebGPU device lost (2): Device was destroyed.

----------------------------------- With Prefill Optimization ----

C:\onnxruntime>C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500                                                                                                                                                               
Batch size: 1, prompt tokens: 501, tokens to generate: 128
Prompt processing (time to first token):
        avg (us):       1.2135e+07
        avg (tokens/s): 41.2856                                 << 
        p50 (us):       1.21288e+07
        stddev (us):    21282.1
        n:              5 * 501 token(s)
Token generation:
        avg (us):       78945.3
        avg (tokens/s): 12.667
        p50 (us):       78900.7
        stddev (us):    2232.43
        n:              635 * 1 token(s)
Token sampling:
        avg (us):       20.5608
        avg (tokens/s): 48636.3
        p50 (us):       18.7
        stddev (us):    19.0409
        n:              640 * 1 token(s)
E2E generation (entire generation loop):
        avg (ms):       22163.8
        p50 (ms):       22160.1
        stddev (ms):    31.3122
        n:              5
Peak working set size (bytes): 5478862848
WebGPU device lost (2): Device was destroyed.
```
  • Loading branch information
sushraja-msft authored Dec 11, 2024
1 parent d8de3c4 commit 8800830
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 29 deletions.
201 changes: 172 additions & 29 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ std::string QuantizedDataType(int components) {
}
}

constexpr unsigned int kMinSequenceLengthForPrefillOptimization = 16;
} // namespace

ONNX_OPERATOR_KERNEL_EX(
Expand Down Expand Up @@ -321,6 +322,121 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}

Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddInput("scales", ShaderUsage::UseUniform);
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
// This shader uses uniforms with the M,N,K convention from traditional matrix multiplicatiion
// M is the number of rows in A and M rows in the output.
// N is the number of columns in B and N columns in the output.
// K is the hidden/shared dimension number of columns in A and K rows in B.
// Note in matmulnbits, B matrix is already transposed, however the following remains true
// for the shader below M describes A, N describes B and K is the hidden/shared dimension.
// K4/K8 are simply K divided by 4 or 8 respectively.
shader.AdditionalImplementation() << R"INIT_SECTION(
// Matrix dimensions and quantization parameters
const TILE_SIZE : u32 = 16u;
const VALUES_PER_VEC4 : u32 = 4u;
const QUANTIZATION_BLOCK_SIZE : u32 = 32;
// We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM,
// so we use BLOCKS_PER_CYCLE as 2u, or process weights 2 blocks at a time.
// This uses all 16 lanes on 12th gen intel chips.
const BLOCKS_PER_CYCLE : u32 = 2u;
const INNER_DIMENSION_ITEMS_PER_CYCLE : u32 = 16u; // (QUANTIZATION_BLOCK_SIZE/VALUES_PER_VEC4)*BLOCKS_PER_CYCLE
const VECTORIZED_QUANTIZATION_BLOCK_SIZE: u32 = 8u; // QUANTIZATION_BLOCK_SIZE / VALUES_PER_VEC4;
//Shared memory
var<workgroup> tile_A : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
var<workgroup> tile_B : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
var<workgroup> tile_O : array<array<output_value_t, TILE_SIZE>, TILE_SIZE>;
fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32)
{
if (a_global >= uniforms.M) {
return;
}
let local_A = input_a[a_global*uniforms.K4+step_idx*INNER_DIMENSION_ITEMS_PER_CYCLE+parallel_id];
tile_A[slot][parallel_id] = local_A;
}
fn getBScale(slot: u32, b_global : u32, vec_step_idx : u32, scale_idx: u32) -> output_value_t
{
// Since scales are output_value_t holding 1 for every 32 values, vec_step_idx jumps over 64 weights at
// a time or 2 scales at every step.
let scale_offset = vec_step_idx*2;
let idx = u32(b_global*(uniforms.K/QUANTIZATION_BLOCK_SIZE)+scale_offset);
return scales[idx+scale_idx];
}
fn loadB(slot: u32, b_global : u32, vec_step_idx : u32, parallel_id : u32)
{
if (b_global >= uniforms.N) {
return;
}
let scale = getBScale(slot, b_global, vec_step_idx, u32(parallel_id/VECTORIZED_QUANTIZATION_BLOCK_SIZE));
let idx:u32 = parallel_id;
if (idx % 2 == 0)
{
// Weights are u32 holding 8 values each, each step (vec_step_idx) jumps over 64 weights at a time.
// Therefore the weight_offset begin for the current step would be vec_step_idx * 64 if weight
// elements were holding one element each. For the case of each element holding 8 values, begin
// would become vec_step_idx * 64/8 or vec_step_idx * 8.
var weight_offset:u32 = (vec_step_idx*8)+ u32(idx/2);
let b_value = input_b[b_global*uniforms.K8+weight_offset];
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
tile_B[slot][idx].x = (output_value_t(b_value_lower[0]) - 8.0) * scale;
tile_B[slot][idx].y = (output_value_t(b_value_upper[0]) - 8.0) * scale;
tile_B[slot][idx].z = (output_value_t(b_value_lower[1]) - 8.0) * scale;
tile_B[slot][idx].w = (output_value_t(b_value_upper[1]) - 8.0) * scale;
tile_B[slot][idx+1].x = (output_value_t(b_value_lower[2]) - 8.0)* scale;
tile_B[slot][idx+1].y = (output_value_t(b_value_upper[2]) - 8.0)* scale;
tile_B[slot][idx+1].z = (output_value_t(b_value_lower[3]) - 8.0)* scale;
tile_B[slot][idx+1].w = (output_value_t(b_value_upper[3]) - 8.0)* scale;
}
}
fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t
{
var sum:output_value_t = 0;
for (var idx:u32 = 0 ; idx < INNER_DIMENSION_ITEMS_PER_CYCLE; idx++)
{
sum += dot(tile_A[slot_a][idx], tile_B[slot_b][idx]);
}
return sum;
}
)INIT_SECTION";

shader.MainFunctionBody() << R"MAIN_FN(
// Indexing with idx,idy instead of using a 2d dispatch of TILE_SIZE, TILE_SIZE
// appears to give a performance win on Intel Gen12LP architecture.
// This is likley because of locality of memory access, idy below in this approach
// is the same as subgroup_id or lane id, while idx is the wave_id.
// The work distribution therefore keeps memory accesses close together in
// a single wave in this approach of indexing.
let idx = u32(local_idx / TILE_SIZE);
let idy = u32(local_idx % TILE_SIZE);
let a_global_base = workgroup_id.x * TILE_SIZE;
let b_global_base = workgroup_id.y * TILE_SIZE;
let step_count:u32 = u32(uniforms.K/(BLOCKS_PER_CYCLE*QUANTIZATION_BLOCK_SIZE));
for (var vec_step:u32 = 0; vec_step < step_count; vec_step++)
{
workgroupBarrier();
loadA(idx, a_global_base+idx, vec_step, idy);
loadB(idx, b_global_base+idx, vec_step, idy);
workgroupBarrier();
let result = computeDotProduct(idx, idy);
tile_O[idx][idy]+=result;
}
workgroupBarrier();
if (a_global_base+idx < uniforms.M && b_global_base+idy < uniforms.N) {
output[(a_global_base+idx) * uniforms.N + b_global_base + idy] = tile_O[idx][idy];
}
)MAIN_FN";
return Status::OK();
}

Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const Tensor* a = context.Input(0);
const Tensor* b = context.Input(1);
Expand Down Expand Up @@ -360,38 +476,65 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
context.AdapterInfo().architecture == std::string_view{"gen-12lp"} &&
block_size == 32;
const bool has_zero_points = zero_points != nullptr;
// TODO: Support output_number > 1. Some cases are failed when output_number > 1.
// const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1;
constexpr uint32_t output_number = 1;
MatMulNBitsProgram program{output_number, gsl::narrow<int>(components_b), has_zero_points, use_block32};

if (use_block32) {
components = 1;
constexpr uint32_t workgroup_size = 128;
const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
: 1;
const uint32_t workgroup_x = workgroup_size / workgroup_y;
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
program.SetDispatchGroupSize(data_size / components / workgroup_y);

if (use_block32 && batch_count == 1 &&
components_a == 4 && components_b == 4 &&
!has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) {
MatMulNBitsProgramPrefill program;
constexpr int32_t tile_size = 16;
// subgroup_size here controls how many elements of the hidden dimension we load in a cycle.
// MatMulNBitsProgramPrefill does not use any of the subgroup wgsl instructions. The subgroup
// size just helps with optimal lane usage in the shader.
constexpr int32_t subgroup_size = 16;
program.SetWorkgroupSize(tile_size * subgroup_size);
program.SetDispatchGroupSize((M + tile_size - 1) / tile_size,
(N + tile_size - 1) / tile_size,
1);
program
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
{scales, ProgramTensorMetadataDependency::None}})
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
{static_cast<uint32_t>(K)},
{static_cast<uint32_t>(K / 4)},
{static_cast<uint32_t>(K / 8)}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)});
return context.RunProgram(program);
} else {
program.SetDispatchGroupSize(data_size / components / output_number);
}
// TODO: Support output_number > 1. Some cases are failed when output_number > 1.
// const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1;
constexpr uint32_t output_number = 1;
MatMulNBitsProgram program{output_number, gsl::narrow<int>(components_b), has_zero_points, use_block32};

if (use_block32) {
components = 1;
constexpr uint32_t workgroup_size = 128;
const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
: 1;
const uint32_t workgroup_x = workgroup_size / workgroup_y;
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
program.SetDispatchGroupSize(data_size / components / workgroup_y);
} else {
program.SetDispatchGroupSize(data_size / components / output_number);
}

TensorShape reshaped_a_shape{batch_count, M, K / components_a};
TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
TensorShape reshaped_y_shape{batch_count, M, N / components};

TensorShape reshaped_a_shape{batch_count, M, K / components_a};
TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
TensorShape reshaped_y_shape{batch_count, M, N / components};

program
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow<int>(components_a)},
{b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow<int>(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)},
{scales, ProgramTensorMetadataDependency::None}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(components)})
.AddUniformVariable({block_size})
.CacheHint(std::to_string(output_number));
if (has_zero_points) {
program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
program
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow<int>(components_a)},
{b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow<int>(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)},
{scales, ProgramTensorMetadataDependency::None}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(components)})
.AddUniformVariable({block_size})
.CacheHint(std::to_string(output_number));
if (has_zero_points) {
program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
return context.RunProgram(program);
}
return context.RunProgram(program);
}

} // namespace webgpu
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
bool use_block32_;
};

class MatMulNBitsProgramPrefill final : public Program<MatMulNBitsProgramPrefill> {
public:
MatMulNBitsProgramPrefill() : Program{"MatMulNBitsPrefill"} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"K4", ProgramUniformVariableDataType::Uint32},
{"K8", ProgramUniformVariableDataType::Uint32});
};

class MatMulNBits final : public WebGpuKernel {
public:
MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) {
Expand Down

0 comments on commit 8800830

Please sign in to comment.