Skip to content

Commit

Permalink
Run linter
Browse files Browse the repository at this point in the history
  • Loading branch information
sushraja-msft committed Dec 9, 2024
1 parent 1ee552c commit ffb2dab
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 21 deletions.
25 changes: 11 additions & 14 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t
}
)INIT_SECTION";

shader.MainFunctionBody() << R"MAIN_FN(
shader.MainFunctionBody() << R"MAIN_FN(
// Indexing with idx,idy instead of using a 2d dispatch of TILE_SIZE, TILE_SIZE
// appears to give a performance win on Intel Gen12LP architecture.
// This could likley because of locality of memory access that changes with
Expand Down Expand Up @@ -465,30 +465,27 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
const bool has_zero_points = zero_points != nullptr;

if (use_block32 && batch_count == 1 &&
components_a == 4 && components_b == 4 &&
!has_zero_points && M >= kMinSequenceLengthForPrefillOptimization)
{
components_a == 4 && components_b == 4 &&
!has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) {
MatMulNBitsProgramPrefill program{false};
constexpr int32_t tile_size = 16;
constexpr int32_t subgroup_size = 16;
program.SetWorkgroupSize(tile_size * subgroup_size);
program.SetDispatchGroupSize((M + tile_size - 1) / tile_size,
(N + tile_size - 1) / tile_size,
1);
(N + tile_size - 1) / tile_size,
1);
program
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
{scales, ProgramTensorMetadataDependency::None}})
.AddUniformVariables({ {static_cast<uint32_t>(M)},
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
{static_cast<uint32_t>(K)},
{static_cast<uint32_t>(K/4)},
{static_cast<uint32_t>(K/8)}} )
{static_cast<uint32_t>(K / 4)},
{static_cast<uint32_t>(K / 8)}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)});
return context.RunProgram(program);
}
else
{
return context.RunProgram(program);
} else {
// TODO: Support output_number > 1. Some cases are failed when output_number > 1.

Check warning on line 489 in onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc:489: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1;
constexpr uint32_t output_number = 1;
Expand All @@ -498,7 +495,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
components = 1;
constexpr uint32_t workgroup_size = 128;
const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
: 1;
: 1;
const uint32_t workgroup_x = workgroup_size / workgroup_y;
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
program.SetDispatchGroupSize(data_size / components / workgroup_y);
Expand Down
13 changes: 6 additions & 7 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,19 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
bool use_block32_;
};


class MatMulNBitsProgramPrefill final : public Program<MatMulNBitsProgramPrefill> {
public:
MatMulNBitsProgramPrefill(bool has_zero_points) : Program{"MatMulNBitsPrefill"},
has_zero_points_{has_zero_points} {
has_zero_points_{has_zero_points} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"K4", ProgramUniformVariableDataType::Uint32},
{"K8", ProgramUniformVariableDataType::Uint32});
{"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"K4", ProgramUniformVariableDataType::Uint32},
{"K8", ProgramUniformVariableDataType::Uint32});

private:
bool has_zero_points_;
Expand Down

0 comments on commit ffb2dab

Please sign in to comment.