Implement 2d tiled matmulnbits specialized for prefill #23058

Merged · 3 commits · Dec 11, 2024
Changes from 2 commits
185 changes: 156 additions & 29 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -39,6 +39,7 @@
}
}

constexpr unsigned int kMinSequenceLengthForPrefillOptimization = 16;
} // namespace

ONNX_OPERATOR_KERNEL_EX(
@@ -321,6 +322,108 @@
return Status::OK();
}

Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddInput("scales", ShaderUsage::UseUniform);
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
shader.AdditionalImplementation() << R"INIT_SECTION(
// Matrix dimensions and quantization parameters
const TILE_SIZE : u32 = 16u;
const VALUES_PER_VEC4 : u32 = 4u;
const QUANTIZATION_BLOCK_SIZE : u32 = 32u;
// We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU,
// so we set BLOCKS_PER_CYCLE to 2u, i.e. we process 2 blocks of weights at a time.
// This uses all 16 lanes on 12th-gen Intel chips.
const BLOCKS_PER_CYCLE : u32 = 2u;
const INNER_DIMENSION_ITEMS_PER_CYCLE : u32 = 16u; // (QUANTIZATION_BLOCK_SIZE/VALUES_PER_VEC4)*BLOCKS_PER_CYCLE
const VECTORIZED_QUANTIZATION_BLOCK_SIZE: u32 = 8u; // QUANTIZATION_BLOCK_SIZE / VALUES_PER_VEC4;
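// Worked example (illustrative values, not a requirement): with K = 4096 the main loop
// runs K / (BLOCKS_PER_CYCLE * QUANTIZATION_BLOCK_SIZE) = 64 steps, and each step stages
// INNER_DIMENSION_ITEMS_PER_CYCLE = 16 vec4 loads (64 scalars along K) per tile row.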

// Shared memory
var<workgroup> tile_A : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
var<workgroup> tile_B : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
var<workgroup> tile_O : array<array<output_value_t, TILE_SIZE>, TILE_SIZE>;

fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32)
{
if (a_global >= uniforms.M) {
return;
}
let local_A = input_a[a_global*uniforms.K4+step_idx*INNER_DIMENSION_ITEMS_PER_CYCLE+parallel_id];
tile_A[slot][parallel_id] = local_A;
}

fn getBScale(slot: u32, b_global : u32, vec_step_idx : u32, scale_idx: u32) -> output_value_t
{
// Each scale is one output_value_t covering QUANTIZATION_BLOCK_SIZE (32) weight values,
// so vec_step_idx advances the scale offset by 2 entries (64 weight values) at a time.
let scale_offset = vec_step_idx*2;
let idx = u32(b_global*(uniforms.K/QUANTIZATION_BLOCK_SIZE)+scale_offset);
return scales[idx+scale_idx];
}

fn loadB(slot: u32, b_global : u32, vec_step_idx : u32, parallel_id : u32)
{
if (b_global >= uniforms.N) {
return;
}
let scale = getBScale(slot, b_global, vec_step_idx, u32(parallel_id/VECTORIZED_QUANTIZATION_BLOCK_SIZE));
let idx:u32 = parallel_id;
if (idx % 2 == 0)
{
// Since each weight word is a u32 holding 8 4-bit values, vec_step_idx advances the weight offset by 8 words (64 values) each time.
var weight_offset:u32 = (vec_step_idx*8)+ u32(idx/2);
let b_value = input_b[b_global*uniforms.K8+weight_offset];
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
tile_B[slot][idx].x = (output_value_t(b_value_lower[0]) - 8.0) * scale;
tile_B[slot][idx].y = (output_value_t(b_value_upper[0]) - 8.0) * scale;
tile_B[slot][idx].z = (output_value_t(b_value_lower[1]) - 8.0) * scale;
tile_B[slot][idx].w = (output_value_t(b_value_upper[1]) - 8.0) * scale;
tile_B[slot][idx+1].x = (output_value_t(b_value_lower[2]) - 8.0) * scale;
tile_B[slot][idx+1].y = (output_value_t(b_value_upper[2]) - 8.0) * scale;
tile_B[slot][idx+1].z = (output_value_t(b_value_lower[3]) - 8.0) * scale;
tile_B[slot][idx+1].w = (output_value_t(b_value_upper[3]) - 8.0) * scale;
}
}

fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t
{
var sum:output_value_t = 0;
for (var idx:u32 = 0 ; idx < INNER_DIMENSION_ITEMS_PER_CYCLE; idx++)
{
sum += dot(tile_A[slot_a][idx], tile_B[slot_b][idx]);
}
return sum;
}
)INIT_SECTION";

shader.MainFunctionBody() << R"MAIN_FN(
// Indexing with idx/idy instead of using a 2d dispatch of TILE_SIZE x TILE_SIZE
// appears to give a performance win on the Intel Gen12LP architecture.
// This is likely because of the locality of memory access, which changes when
// idy maps to consecutive lanes in an EU.
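// For example, with TILE_SIZE = 16: local_idx 0..15 map to (idx = 0, idy = 0..15),
// local_idx 16..31 map to (idx = 1, idy = 0..15), so each run of 16 consecutive
// lanes handles one row of the tile.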
let idx = u32(local_idx / TILE_SIZE);
let idy = u32(local_idx % TILE_SIZE);
let a_global_base = workgroup_id.x * TILE_SIZE;
let b_global_base = workgroup_id.y * TILE_SIZE;
let step_count:u32 = u32(uniforms.K/(BLOCKS_PER_CYCLE*QUANTIZATION_BLOCK_SIZE));
for (var vec_step:u32 = 0; vec_step < step_count; vec_step++)
{
workgroupBarrier();
loadA(idx, a_global_base+idx, vec_step, idy);
loadB(idx, b_global_base+idx, vec_step, idy);
workgroupBarrier();
let result = computeDotProduct(idx, idy);
tile_O[idx][idy]+=result;
}
workgroupBarrier();
if (a_global_base+idx < uniforms.M && b_global_base+idy < uniforms.N) {
output[(a_global_base+idx) * uniforms.N + b_global_base + idy] = tile_O[idx][idy];
}
)MAIN_FN";
return Status::OK();
}
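
// For reference, the nibble layout decoded by loadB can be mirrored on the CPU.
// The sketch below is purely illustrative and not part of this change (the helper
// name and standalone function are hypothetical); it assumes the same symmetric
// 4-bit scheme used above: each packed uint32_t holds eight 4-bit weights, byte k
// contributing weight 2k in its low nibble and weight 2k + 1 in its high nibble,
// each shifted by -8 and multiplied by the block scale.
//
// #include <array>
// #include <cstdint>
//
// std::array<float, 8> DequantizeNibbleWord(uint32_t packed, float scale) {
//   std::array<float, 8> out{};
//   for (int byte = 0; byte < 4; ++byte) {
//     const uint32_t b = (packed >> (8 * byte)) & 0xFFu;
//     out[2 * byte] = (static_cast<float>(b & 0x0Fu) - 8.0f) * scale;   // low nibble -> even slot
//     out[2 * byte + 1] = (static_cast<float>(b >> 4) - 8.0f) * scale;  // high nibble -> odd slot
//   }
//   return out;
// }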

Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const Tensor* a = context.Input(0);
const Tensor* b = context.Input(1);
@@ -360,38 +463,62 @@
context.AdapterInfo().architecture == std::string_view{"gen-12lp"} &&
block_size == 32;
const bool has_zero_points = zero_points != nullptr;

if (use_block32 && batch_count == 1 &&
components_a == 4 && components_b == 4 &&
!has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) {
MatMulNBitsProgramPrefill program{false};
constexpr int32_t tile_size = 16;
constexpr int32_t subgroup_size = 16;
program.SetWorkgroupSize(tile_size * subgroup_size);
program.SetDispatchGroupSize((M + tile_size - 1) / tile_size,
(N + tile_size - 1) / tile_size,
1);
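// Illustrative dispatch arithmetic (example values only): with, say, M = 512 and N = 4096
// this launches ceil(512/16) x ceil(4096/16) = 32 x 256 workgroups of
// tile_size * subgroup_size = 256 threads each.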
program
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
{scales, ProgramTensorMetadataDependency::None}})
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
{static_cast<uint32_t>(K)},
{static_cast<uint32_t>(K / 4)},
{static_cast<uint32_t>(K / 8)}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)});
return context.RunProgram(program);
} else {
// TODO: Support output_number > 1. Some cases are failed when output_number > 1.
// const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1;
constexpr uint32_t output_number = 1;
MatMulNBitsProgram program{output_number, gsl::narrow<int>(components_b), has_zero_points, use_block32};

if (use_block32) {
components = 1;
constexpr uint32_t workgroup_size = 128;
const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
: 1;
const uint32_t workgroup_x = workgroup_size / workgroup_y;
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
program.SetDispatchGroupSize(data_size / components / workgroup_y);
} else {
program.SetDispatchGroupSize(data_size / components / output_number);
}

TensorShape reshaped_a_shape{batch_count, M, K / components_a};
TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
TensorShape reshaped_y_shape{batch_count, M, N / components};

program
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow<int>(components_a)},
{b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow<int>(components_b * 4 /** b will be accessed as uint32, which includes 4 uint8, so we need to multiply by 4. */)},
{scales, ProgramTensorMetadataDependency::None}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(components)})
.AddUniformVariable({block_size})
.CacheHint(std::to_string(output_number));
if (has_zero_points) {
program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
return context.RunProgram(program);
}
}

} // namespace webgpu
18 changes: 18 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
@@ -31,6 +31,24 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
bool use_block32_;
};

class MatMulNBitsProgramPrefill final : public Program<MatMulNBitsProgramPrefill> {
public:
MatMulNBitsProgramPrefill(bool has_zero_points) : Program{"MatMulNBitsPrefill"},
has_zero_points_{has_zero_points} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"K4", ProgramUniformVariableDataType::Uint32},
{"K8", ProgramUniformVariableDataType::Uint32});

private:
bool has_zero_points_;
};

class MatMulNBits final : public WebGpuKernel {
public:
MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) {