Commit 27043c9: format
qjia7 committed Oct 17, 2024
1 parent 5a0a106 commit 27043c9
Showing 2 changed files with 23 additions and 23 deletions.
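The commit is formatting-only: in matmul_nbits.cc, shader-source lines that previously relied on C++ adjacent-string-literal concatenation inside a chained stream expression now begin with an explicit `<<` (plus one missing space before an operator), a nested ternary is re-broken across two lines, and matmul_nbits.h only re-indents a constructor initializer list. A minimal before/after sketch of the `<<` pattern follows; the function names, stream, and literals are illustrative, not taken from the file:

```cpp
#include <iostream>
#include <sstream>
#include <string>

// Before: only the first line carries "<<"; the later string literals are glued
// to it by adjacent-literal concatenation, which is harder to keep aligned once
// an expression (here `tile_size`) appears mid-chain.
std::string EmitBefore(int tile_size) {
  std::ostringstream os;
  os << "  let row = output_indices[1];\n"
        "  let col = output_indices[2];\n"
        "  let num_tiles = " << tile_size << ";\n";
  return os.str();
}

// After: every line starts with its own "<<", so the whole chain lines up under
// the first operator and formats stably.
std::string EmitAfter(int tile_size) {
  std::ostringstream os;
  os << "  let row = output_indices[1];\n"
     << "  let col = output_indices[2];\n"
     << "  let num_tiles = " << tile_size << ";\n";
  return os.str();
}

int main() {
  // Both variants generate identical shader text; only the C++ layout differs.
  std::cout << (EmitBefore(4) == EmitAfter(4) ? "identical\n" : "different\n");
}
```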
40 changes: 20 additions & 20 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -64,23 +64,23 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
const uint32_t block_size = 32;
const uint32_t blocks_per_tile = tile_size / block_size;
shader.AdditionalImplementation() << "var<workgroup> sub_a: array<input_a_value_t, " << a_length_per_tile << ">;\n"
<< "var<workgroup> inter_results: array<array<output_value_t, " << WorkgroupSizeX() << ">, "<< WorkgroupSizeY() << ">;\n";
<< "var<workgroup> inter_results: array<array<output_value_t, " << WorkgroupSizeX() << ">, " << WorkgroupSizeY() << ">;\n";
std::string offset = "workgroup_idx * " + std::to_string(WorkgroupSizeY());
shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n"
<< " let col = output_indices[2];\n"
" let row = output_indices[1];\n"
" let batch = output_indices[0];\n"
" let n_blocks_per_col = uniforms.input_b_shape[1];\n"
" let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n"
<< " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n"
<< " // Loop over shared dimension.\n"
" for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n"
" let a_col_start = tile * " << a_length_per_tile << ";\n"
" // load one tile A data into shared memory.\n"
" for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n"
<< " let a_col_start = tile * " << a_length_per_tile << ";\n"
<< " // load one tile A data into shared memory.\n"
<< " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n"
<< " let a_col = a_col_start + a_offset;\n"
" if (a_col < uniforms.input_a_shape[2]) {\n"
" sub_a[a_offset] = " << a.GetByIndices("input_a_indices_t(batch, row, a_col)") << ";\n"
" } else {\n"
<< " sub_a[a_offset] = " << a.GetByIndices("input_a_indices_t(batch, row, a_col)") << ";\n"
<< " } else {\n"
" sub_a[a_offset] = input_a_value_t(0);\n"
" }\n"
" }\n"
@@ -97,14 +97,14 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
" let zero_point_nibble_offset: u32 = block & 0x1u;\n"
" let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n"
<< " let zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n"
" let zero_point = output_element_t((zero_point_word) & 0xFu);\n";
<< " let zero_point = output_element_t((zero_point_word) & 0xFu);\n";
} else {
shader.MainFunctionBody() << " // The default zero point is 8 for unsigned 4-bit quantization.\n"
" let zero_point = output_element_t(8.0);\n";
}
shader.MainFunctionBody() << " let scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n"
" let b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n"
" var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n"
<< " let b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n"
<< " var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n"
<< " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n";
switch (a.NumComponents()) {
case 1:
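As an aside, the zero-point code in this hunk reads one 4-bit value per block out of u32-packed storage: the bit offset is byte_offset * 8 + nibble_offset * 4, and the low nibble of the shifted word is the zero point; when no zero_points input is given, the default of 8 is used. Below is a host-side sketch of that arithmetic with hypothetical helper and parameter names (the shader derives the word index and byte offset on lines outside this hunk):

```cpp
#include <cstdint>
#include <vector>

// Unpack the 4-bit zero point for a given block from u32-packed words, mirroring
// the shader arithmetic above: shift by (byte_offset * 8 + nibble_offset * 4) and
// mask the low nibble. The indexing helpers here are illustrative, not the kernel's.
uint32_t UnpackZeroPoint(const std::vector<uint32_t>& packed,
                         uint32_t word_index,
                         uint32_t byte_offset,     // byte within the u32 word
                         uint32_t nibble_offset) { // 0 = low nibble, 1 = high nibble
  const uint32_t bits_offset = (byte_offset << 3) + (nibble_offset << 2);
  return (packed[word_index] >> bits_offset) & 0xFu;
}

// When the optional zero_points input is absent, unsigned 4-bit quantization
// assumes the midpoint zero point of 8 (see the else-branch above).
constexpr uint32_t kDefaultZeroPoint = 8;
```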
@@ -139,18 +139,18 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
}
shader.MainFunctionBody() << ")) * scale;\n"
" inter_results[local_id.y][local_id.x] += dot(a_data0, b_dequantized_values[0]) + dot(a_data1, b_dequantized_values[1]);\n"
" word_offset += " << 8 / a.NumComponents() << ";\n"
" }\n"
<< " word_offset += " << 8 / a.NumComponents() << ";\n"
<< " }\n"
" workgroupBarrier();\n"
" }\n"
" if (local_idx < " << WorkgroupSizeY() << ") {\n"
" var output_value = output_value_t(0);\n"
" for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n"
" output_value += inter_results[local_idx][b];\n"
<< " if (local_idx < " << WorkgroupSizeY() << ") {\n"
<< " var output_value = output_value_t(0);\n"
<< " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n"
<< " output_value += inter_results[local_idx][b];\n"
" }\n"
" if (col + local_idx < uniforms.output_shape[2]) {\n"
" " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n"
" }\n"
<< " " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n"
<< " }\n"
" }\n";
} else {
const std::string quantized_data_type = QuantizedDataType(a.NumComponents());
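The end of this hunk is the workgroup reduction: each invocation accumulates its partial dot products into inter_results[local_id.y][local_id.x], and after the barrier the first WorkgroupSizeY() invocations each sum one row and write a single output column. A scalar C++ model of that final reduction is sketched below; the workgroup dimensions are illustrative values for the block32 path, not constants from the file:

```cpp
#include <array>
#include <cstddef>

// Scalar model of the workgroup reduction emitted above: WG_Y output columns,
// each accumulated across WG_X partial sums held in shared memory.
constexpr std::size_t WG_X = 16;  // stand-in for WorkgroupSizeX()
constexpr std::size_t WG_Y = 8;   // stand-in for WorkgroupSizeY()

float ReduceColumn(const std::array<std::array<float, WG_X>, WG_Y>& inter_results,
                   std::size_t local_idx) {
  // Mirrors: for (var b = 0u; b < WG_X; b++) { output_value += inter_results[local_idx][b]; }
  float output_value = 0.0f;
  for (std::size_t b = 0; b < WG_X; ++b) {
    output_value += inter_results[local_idx][b];
  }
  return output_value;
}
```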
@@ -359,15 +359,15 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
if (use_block32) {
components = 1;
const uint32_t workgroup_size = 128;
- const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4 : 1;
+ const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
+                                                           : 1;
const uint32_t workgroup_x = workgroup_size / workgroup_y;
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
program.SetDispatchGroupSize(data_size / components / workgroup_y);
} else {
program.SetDispatchGroupSize(data_size / components / output_number);
}


TensorShape reshaped_a_shape{batch_count, M, K / components_a};
TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
TensorShape reshaped_y_shape{batch_count, M, N / components};
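This last hunk is also formatting-only; the workgroup selection it touches is unchanged. For clarity, here is a standalone sketch of that selection logic. The function name is hypothetical, and the meaning of data_size and components is assumed from context rather than shown in the diff:

```cpp
#include <cstdint>

struct WorkgroupConfig {
  uint32_t x;
  uint32_t y;
  uint32_t dispatch_groups;
};

// Mirrors the block32 path above: pick the widest workgroup_y that divides N
// (8, then 4, otherwise 1), keep x * y == 128, and launch one group per
// workgroup_y output columns. `data_size` and `components` follow the hunk's
// variables; their exact definitions sit outside this diff.
WorkgroupConfig PickBlock32Workgroup(uint32_t N, uint32_t data_size,
                                     uint32_t components = 1) {
  const uint32_t workgroup_size = 128;
  const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4 : 1;
  const uint32_t workgroup_x = workgroup_size / workgroup_y;
  return {workgroup_x, workgroup_y, data_size / components / workgroup_y};
}
```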
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
@@ -15,9 +15,9 @@ using namespace onnxruntime::webgpu;
class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
public:
MatMulNBitsProgram(uint32_t output_number, int components_b, bool has_zero_points, bool use_block32) : Program{"MatMulNBits"},
- output_number_{output_number},
- components_b_{components_b},
- has_zero_points_{has_zero_points},
+ output_number_{output_number},
+ components_b_{components_b},
+ has_zero_points_{has_zero_points},
use_block32_{use_block32} {
}

