Implement 2d tiled matmulnbits specialized for prefill #23058

Merged · 3 commits · Dec 11, 2024
201 changes: 172 additions & 29 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -39,6 +39,7 @@
}
}

constexpr unsigned int kMinSequenceLengthForPrefillOptimization = 16;
} // namespace

ONNX_OPERATOR_KERNEL_EX(
@@ -321,6 +322,121 @@
return Status::OK();
}

Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddInput("scales", ShaderUsage::UseUniform);
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
// This shader uses uniforms with the M,N,K convention from traditional matrix multiplication:
// M is the number of rows in A and in the output.
// N is the number of columns in B and in the output.
// K is the hidden/shared dimension, i.e. the number of columns in A and of rows in B.
// Note that in MatMulNBits the B matrix is already transposed; nevertheless, for the shader
// below M describes A, N describes B, and K is the hidden/shared dimension.
// K4/K8 are simply K divided by 4 or 8 respectively.
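// Illustrative example (hypothetical shapes, not taken from this PR): with M = 128, N = 4096 and
// K = 4096, the host code below would pass K4 = 1024 and K8 = 512 as uniforms.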
shader.AdditionalImplementation() << R"INIT_SECTION(
// Matrix dimensions and quantization parameters
const TILE_SIZE : u32 = 16u;
const VALUES_PER_VEC4 : u32 = 4u;
const QUANTIZATION_BLOCK_SIZE : u32 = 32;
// We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM,
// so we set BLOCKS_PER_CYCLE to 2u, i.e. we process weights two blocks at a time.
// This uses all 16 lanes on 12th gen Intel chips.
const BLOCKS_PER_CYCLE : u32 = 2u;
const INNER_DIMENSION_ITEMS_PER_CYCLE : u32 = 16u; // (QUANTIZATION_BLOCK_SIZE/VALUES_PER_VEC4)*BLOCKS_PER_CYCLE
const VECTORIZED_QUANTIZATION_BLOCK_SIZE: u32 = 8u; // QUANTIZATION_BLOCK_SIZE / VALUES_PER_VEC4;
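// For reference: each step therefore covers BLOCKS_PER_CYCLE * QUANTIZATION_BLOCK_SIZE = 64
// scalar elements of K, i.e. INNER_DIMENSION_ITEMS_PER_CYCLE = 16 vec4 loads per tile row.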

// Shared memory
var<workgroup> tile_A : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
var<workgroup> tile_B : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
var<workgroup> tile_O : array<array<output_value_t, TILE_SIZE>, TILE_SIZE>;

fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32)
{
if (a_global >= uniforms.M) {
return;
}
let local_A = input_a[a_global*uniforms.K4+step_idx*INNER_DIMENSION_ITEMS_PER_CYCLE+parallel_id];
tile_A[slot][parallel_id] = local_A;
}

fn getBScale(slot: u32, b_global : u32, vec_step_idx : u32, scale_idx: u32) -> output_value_t
{
// Scales are output_value_t, one per 32 weight values. Each vec_step_idx step jumps over
// 64 weights, i.e. 2 scales at a time.
let scale_offset = vec_step_idx*2;
let idx = u32(b_global*(uniforms.K/QUANTIZATION_BLOCK_SIZE)+scale_offset);
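// Worked example (illustrative): with uniforms.K = 64, each row of B has K / QUANTIZATION_BLOCK_SIZE = 2
// scales; at vec_step_idx = 0, lanes 0-7 (scale_idx = 0) read scales[b_global*2] and
// lanes 8-15 (scale_idx = 1) read scales[b_global*2 + 1].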
return scales[idx+scale_idx];
}

fn loadB(slot: u32, b_global : u32, vec_step_idx : u32, parallel_id : u32)
{
if (b_global >= uniforms.N) {
return;
}
let scale = getBScale(slot, b_global, vec_step_idx, u32(parallel_id/VECTORIZED_QUANTIZATION_BLOCK_SIZE));
let idx:u32 = parallel_id;
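// Only even lanes do work below: each even lane dequantizes one packed u32 (8 weights) into
// tile_B[slot][idx] and tile_B[slot][idx+1], so the 8 active lanes together fill all
// 16 vec4s (64 weights) of this tile_B row per step.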
if (idx % 2 == 0)
{
// Each u32 weight element packs 8 values, and each step (vec_step_idx) advances over 64 weights.
// If every element held a single value, the starting weight offset for this step would be
// vec_step_idx * 64; with 8 values per element it becomes vec_step_idx * 64 / 8,
// i.e. vec_step_idx * 8.
var weight_offset:u32 = (vec_step_idx*8)+ u32(idx/2);
let b_value = input_b[b_global*uniforms.K8+weight_offset];
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
tile_B[slot][idx].x = (output_value_t(b_value_lower[0]) - 8.0) * scale;
tile_B[slot][idx].y = (output_value_t(b_value_upper[0]) - 8.0) * scale;
tile_B[slot][idx].z = (output_value_t(b_value_lower[1]) - 8.0) * scale;
tile_B[slot][idx].w = (output_value_t(b_value_upper[1]) - 8.0) * scale;
tile_B[slot][idx+1].x = (output_value_t(b_value_lower[2]) - 8.0) * scale;
tile_B[slot][idx+1].y = (output_value_t(b_value_upper[2]) - 8.0) * scale;
tile_B[slot][idx+1].z = (output_value_t(b_value_lower[3]) - 8.0) * scale;
tile_B[slot][idx+1].w = (output_value_t(b_value_upper[3]) - 8.0) * scale;
}
}

fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t
{
var sum:output_value_t = 0;
for (var idx:u32 = 0 ; idx < INNER_DIMENSION_ITEMS_PER_CYCLE; idx++)
{
sum += dot(tile_A[slot_a][idx], tile_B[slot_b][idx]);
}
return sum;
}
)INIT_SECTION";

shader.MainFunctionBody() << R"MAIN_FN(
// Indexing with idx,idy instead of using a 2d dispatch of TILE_SIZE, TILE_SIZE
// appears to give a performance win on Intel Gen12LP architecture.
// This is likely because of memory access locality: with this indexing, idy below is
// effectively the subgroup/lane id while idx is the wave id, so the work distribution
// keeps the memory accesses of a single wave close together.
let idx = u32(local_idx / TILE_SIZE);
let idy = u32(local_idx % TILE_SIZE);
let a_global_base = workgroup_id.x * TILE_SIZE;
let b_global_base = workgroup_id.y * TILE_SIZE;
let step_count:u32 = u32(uniforms.K/(BLOCKS_PER_CYCLE*QUANTIZATION_BLOCK_SIZE));
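// Each iteration covers BLOCKS_PER_CYCLE * QUANTIZATION_BLOCK_SIZE = 64 elements of K: the
// workgroup stages a TILE_SIZE x 64 slice of A and the matching dequantized slice of B in
// shared memory, then thread (idx, idy) accumulates the partial dot product for output
// element (a_global_base + idx, b_global_base + idy).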
for (var vec_step:u32 = 0; vec_step < step_count; vec_step++)
{
workgroupBarrier();
loadA(idx, a_global_base+idx, vec_step, idy);
loadB(idx, b_global_base+idx, vec_step, idy);
workgroupBarrier();
let result = computeDotProduct(idx, idy);
tile_O[idx][idy]+=result;
}
workgroupBarrier();
if (a_global_base+idx < uniforms.M && b_global_base+idy < uniforms.N) {
output[(a_global_base+idx) * uniforms.N + b_global_base + idy] = tile_O[idx][idy];
}
)MAIN_FN";
return Status::OK();
}

Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const Tensor* a = context.Input(0);
const Tensor* b = context.Input(1);
@@ -360,38 +476,65 @@
context.AdapterInfo().architecture == std::string_view{"gen-12lp"} &&
block_size == 32;
const bool has_zero_points = zero_points != nullptr;
// TODO: Support output_number > 1. Some cases fail when output_number > 1.
// const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1;
constexpr uint32_t output_number = 1;
MatMulNBitsProgram program{output_number, gsl::narrow<int>(components_b), has_zero_points, use_block32};

if (use_block32) {
components = 1;
constexpr uint32_t workgroup_size = 128;
const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
: 1;
const uint32_t workgroup_x = workgroup_size / workgroup_y;
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
program.SetDispatchGroupSize(data_size / components / workgroup_y);

if (use_block32 && batch_count == 1 &&
components_a == 4 && components_b == 4 &&
!has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) {
MatMulNBitsProgramPrefill program;
constexpr int32_t tile_size = 16;
// subgroup_size here controls how many elements of the hidden dimension we load per cycle.
// MatMulNBitsProgramPrefill does not use any of the WGSL subgroup instructions; the subgroup
// size just helps achieve optimal lane usage in the shader.
constexpr int32_t subgroup_size = 16;
program.SetWorkgroupSize(tile_size * subgroup_size);
program.SetDispatchGroupSize((M + tile_size - 1) / tile_size,
(N + tile_size - 1) / tile_size,
1);
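// Illustrative example (hypothetical shapes): for M = 100 and N = 4096 this dispatches
// ceil(100/16) x ceil(4096/16) = 7 x 256 workgroups of tile_size * subgroup_size = 256 threads each.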
program
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
{scales, ProgramTensorMetadataDependency::None}})
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
{static_cast<uint32_t>(K)},
{static_cast<uint32_t>(K / 4)},
{static_cast<uint32_t>(K / 8)}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)});
return context.RunProgram(program);
} else {
program.SetDispatchGroupSize(data_size / components / output_number);
}
// TODO: Support output_number > 1. Some cases fail when output_number > 1.

// const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1;
constexpr uint32_t output_number = 1;
MatMulNBitsProgram program{output_number, gsl::narrow<int>(components_b), has_zero_points, use_block32};

if (use_block32) {
components = 1;
constexpr uint32_t workgroup_size = 128;
const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
: 1;
const uint32_t workgroup_x = workgroup_size / workgroup_y;
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
program.SetDispatchGroupSize(data_size / components / workgroup_y);
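// Illustrative example (hypothetical N): for N = 4096, N % 8 == 0, so workgroup_y = 8
// and workgroup_x = 128 / 8 = 16.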
} else {
program.SetDispatchGroupSize(data_size / components / output_number);
}

TensorShape reshaped_a_shape{batch_count, M, K / components_a};
TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
TensorShape reshaped_y_shape{batch_count, M, N / components};

TensorShape reshaped_a_shape{batch_count, M, K / components_a};
TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
TensorShape reshaped_y_shape{batch_count, M, N / components};

program
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow<int>(components_a)},
{b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow<int>(components_b * 4 /* b will be accessed as uint32, which includes 4 uint8 values, so we need to multiply by 4. */)},
{scales, ProgramTensorMetadataDependency::None}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(components)})
.AddUniformVariable({block_size})
.CacheHint(std::to_string(output_number));
if (has_zero_points) {
program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
program
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow<int>(components_a)},
{b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow<int>(components_b * 4 /* b will be accessed as uint32, which includes 4 uint8 values, so we need to multiply by 4. */)},
{scales, ProgramTensorMetadataDependency::None}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(components)})
.AddUniformVariable({block_size})
.CacheHint(std::to_string(output_number));
if (has_zero_points) {
program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
return context.RunProgram(program);
}
return context.RunProgram(program);
}

} // namespace webgpu
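As a reading aid, here is a minimal scalar sketch of the computation the prefill shader above tiles. It is illustrative only: it assumes 4-bit weights packed eight per uint32 with the low nibble first (matching the unpacking order in loadB), one scale per 32-element block along K, and no zero points (implied zero point of 8); the function name and container choices are hypothetical.

// Scalar reference (sketch) for C[m][n] = sum_k A[m][k] * dequant(B[n][k]).
#include <cstdint>
#include <vector>

void MatMulNBitsReference(const std::vector<float>& A,       // M x K activations
                          const std::vector<uint32_t>& B,    // N x (K / 8) packed 4-bit weights
                          const std::vector<float>& scales,  // N x (K / 32) block scales
                          std::vector<float>& C,             // M x N output
                          int M, int N, int K) {
  constexpr int kBlockSize = 32;  // quantization block size along K
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        const uint32_t packed = B[n * (K / 8) + k / 8];
        const uint32_t nibble = (packed >> (4 * (k % 8))) & 0xFu;  // low nibble first
        const float scale = scales[n * (K / kBlockSize) + k / kBlockSize];
        acc += A[m * K + k] * (static_cast<float>(nibble) - 8.0f) * scale;
      }
      C[m * N + n] = acc;
    }
  }
}
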
14 changes: 14 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
@@ -31,6 +31,20 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
bool use_block32_;
};

class MatMulNBitsProgramPrefill final : public Program<MatMulNBitsProgramPrefill> {
public:
MatMulNBitsProgramPrefill() : Program{"MatMulNBitsPrefill"} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"K4", ProgramUniformVariableDataType::Uint32},
{"K8", ProgramUniformVariableDataType::Uint32});
};

class MatMulNBits final : public WebGpuKernel {
public:
MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) {