[webgpu] Optimize matmulnbits with M > 1
qjia7 committed Dec 13, 2024
1 parent f43f40f commit d30cf80
Showing 2 changed files with 139 additions and 3 deletions.
128 changes: 125 additions & 3 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -437,6 +437,103 @@ fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t
return Status::OK();
}

Status MatMulNBitsWithLargeMProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform);
const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);

constexpr uint32_t tile_m = 4;
const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY();
const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8;  // each uint32 packs 8 4-bit weights.
const uint32_t a_length_per_tile = tile_size / a.NumComponents();
constexpr uint32_t block_size = 32;
const uint32_t blocks_per_tile = tile_size / block_size;
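// Tiling scheme: each workgroup computes a tile_m x WorkgroupSizeY() tile of the
// output. One tile of A (tile_m rows x a_length_per_tile vectors) is staged in
// workgroup memory; each thread then dequantizes one block of B per iteration and
// accumulates its dot products into inter_results, which is reduced at the end.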
shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n"
" if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n"
<< " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n"
<< " } else {\n"
" return input_a_value_t(0);\n"
" }\n"
"}\n"
<< "var<workgroup> sub_a: array<array<input_a_value_t, " << a_length_per_tile << ">," << tile_m << ">;\n"
<< "var<workgroup> inter_results: array<array<array<output_value_t, " << WorkgroupSizeX() << ">, " << WorkgroupSizeY() << ">," << tile_m << ">;\n";
shader.MainFunctionBody() << " let col = workgroup_id.x * " << WorkgroupSizeY() << ";\n"
<< " let row = workgroup_id.y * " << tile_m << ";\n"
<< " let batch = workgroup_id.z;\n"
" let n_blocks_per_col = uniforms.input_b_shape[1];\n"
<< " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n"
// Loop over shared dimension.
<< " for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n"
<< " let a_col_start = tile * " << a_length_per_tile << ";\n"
<< " // load one tile A data into shared memory.\n"
<< " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n"
<< " let a_col = a_col_start + a_offset;\n";
for (uint32_t i = 0; i < tile_m; i++) {
shader.MainFunctionBody() << " sub_a[" << i << "][a_offset] = mm_readA(batch, row + " << i << ", a_col);\n";
}
shader.MainFunctionBody() << " }\n"
" workgroupBarrier();\n"
// Each thread processes one block.
" let b_row = col + local_id.y;\n"
<< " let block = tile * " << blocks_per_tile << " + local_id.x;\n";
if (has_zero_points_) {
const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform);
shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n"
" let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u);\n"
" let zero_point_word_index = zero_point_byte_count >> 0x2u;\n"
" let zero_point_byte_offset = zero_point_byte_count & 0x3u;\n"
" let zero_point_nibble_offset: u32 = block & 0x1u;\n"
" let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n"
<< " let zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n"
<< " let zero_point = output_element_t((zero_point_word) & 0xFu);\n";
} else {
// The default zero point is 8 for unsigned 4-bit quantization.
shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n";
}
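// Zero points are packed two per byte (4 bits each), so each column of B owns
// (n_blocks_per_col + 1) / 2 bytes, addressed through u32 words. Worked example:
// b_row = 1, block = 5, n_blocks_per_col = 20 => zero_point_bytes_per_col = 10,
// byte_count = 1 * 10 + (5 >> 1) = 12, word_index = 3, byte_offset = 0,
// nibble_offset = 1, bits_offset = (0 << 3) + (1 << 2) = 4: the zero point is
// bits [4, 8) of word 3.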
shader.MainFunctionBody() << " var scale = output_element_t(0);\n"
" var b_data = input_b_value_t(0);\n"
<< " if (block < n_blocks_per_col) {\n"
<< " scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n"
<< " b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n"
<< " }\n"
<< " var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n"
<< " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n";
shader.MainFunctionBody() << " let b_value = b_data";
if (components_b_ > 1) {
shader.MainFunctionBody() << "[i]";
}
shader.MainFunctionBody() << ";\n"
" let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);\n"
" let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);\n"
" let b_quantized_values = mat2x4<output_element_t>(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n"
" let b_dequantized_values = (b_quantized_values - mat2x4<output_element_t>(";
for (int i = 0; i < 8; i++) {
shader.MainFunctionBody() << "zero_point";
if (i < 7) {
shader.MainFunctionBody() << ", ";
}
}
shader.MainFunctionBody() << ")) * scale;\n";
for (uint32_t i = 0; i < tile_m; i++) {
shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(sub_a[" << i << "][word_offset], b_dequantized_values[0]) + dot(sub_a[" << i << "][word_offset + 1], b_dequantized_values[1]);\n";
}
shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n"
<< " }\n"
" workgroupBarrier();\n"
" }\n"
<< " if (local_id.y < " << tile_m << ") {\n"
<< " var output_value = output_value_t(0);\n"
<< " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n"
<< " output_value += inter_results[local_id.y][local_id.x][b];\n"
" }\n"
" if (row + local_id.y < uniforms.output_shape[1] && col + local_id.x < uniforms.output_shape[2]) {\n"
<< " " << y.SetByIndices("output_indices_t(batch, row + local_id.y, col + local_id.x)", "output_value") << ";\n"
<< " }\n"
" }\n";
return Status::OK();
}
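/*
  Roughly the WGSL this emits, assuming the configuration chosen in
  ComputeInternal below (8x8 workgroup, components_b_ == 4, vec4 input_a, so
  tile_size == 256, a_length_per_tile == 64, blocks_per_tile == 8). Uniform and
  type declarations, mm_readA, and the dequantization details are elided:

    var<workgroup> sub_a: array<array<input_a_value_t, 64>, 4>;
    var<workgroup> inter_results: array<array<array<output_value_t, 8>, 8>, 4>;

    let col = workgroup_id.x * 8;
    let row = workgroup_id.y * 4;
    let batch = workgroup_id.z;
    let n_blocks_per_col = uniforms.input_b_shape[1];
    let num_tiles = (n_blocks_per_col - 1) / 8 + 1;
    for (var tile: u32 = 0; tile < num_tiles; tile += 1) {
      // Each of the 64 threads stages one vec4 of A per output row.
      let a_col_start = tile * 64;
      for (var a_offset = local_idx; a_offset < 64; a_offset += 64) {
        let a_col = a_col_start + a_offset;
        sub_a[0][a_offset] = mm_readA(batch, row + 0, a_col);
        // ... rows 1..3 likewise ...
      }
      workgroupBarrier();
      // Each thread owns one 32-element block of B for this tile.
      let b_row = col + local_id.y;
      let block = tile * 8 + local_id.x;
      // ... load scale / b_data, dequantize, then for each row r in 0..3:
      //   inter_results[r][local_id.y][local_id.x] += dot(...) + dot(...);
      workgroupBarrier();
    }
    if (local_id.y < 4) {
      var output_value = output_value_t(0);
      for (var b = 0u; b < 8u; b++) {
        output_value += inter_results[local_id.y][local_id.x][b];
      }
      // bounds-checked store to output[batch, row + local_id.y, col + local_id.x]
    }
*/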

Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const Tensor* a = context.Input(0);
const Tensor* b = context.Input(1);
@@ -477,9 +574,34 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
block_size == 32;
const bool has_zero_points = zero_points != nullptr;

if (M > 1 && components_a == 4 && block_size == 32) {
MatMulNBitsWithLargeMProgram program{gsl::narrow<int>(components_b), has_zero_points};
components = 1;
constexpr uint32_t tile_m = 4;
constexpr uint32_t workgroup_size = 64;
constexpr uint32_t workgroup_y = 8;
constexpr uint32_t workgroup_x = workgroup_size / workgroup_y;
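// Example: batch_count = 1, M = 7, N = 4096 dispatches
// ((4096 + 7) / 8, (7 + 3) / 4, 1) = (512, 2, 1) workgroups, each producing a
// 4-row x 8-column tile of the output.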
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
program.SetDispatchGroupSize((N + workgroup_y - 1) / workgroup_y,
(M + tile_m - 1) / tile_m,
batch_count);

TensorShape reshaped_a_shape{batch_count, M, K / components_a};
TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
TensorShape reshaped_y_shape{batch_count, M, N / components};

program
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow<int>(components_a)},
{b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow<int>(components_b * 4)},
{scales, ProgramTensorMetadataDependency::None}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(components)});
if (has_zero_points) {
program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
return context.RunProgram(program);
} else if (use_block32 && batch_count == 1 &&
components_a == 4 && components_b == 4 &&
!has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) {
MatMulNBitsProgramPrefill program;
constexpr int32_t tile_size = 16;
// subgroup_size here controls how many elements of the hidden dimension we load in a cycle.
14 changes: 14 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
@@ -31,6 +31,20 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
bool use_block32_;
};

class MatMulNBitsWithLargeMProgram final : public Program<MatMulNBitsWithLargeMProgram> {
public:
MatMulNBitsWithLargeMProgram(int components_b, bool has_zero_points) : Program{"MatMulNBitsWithLargeM"},
components_b_{components_b},
has_zero_points_{has_zero_points} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

private:
int components_b_;
bool has_zero_points_;
};

class MatMulNBitsProgramPrefill final : public Program<MatMulNBitsProgramPrefill> {
public:
MatMulNBitsProgramPrefill() : Program{"MatMulNBitsPrefill"} {