From 9267e5380501a7b6f4ec45d1dafca6cbcb400200 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 20 Dec 2024 08:22:53 +0800 Subject: [PATCH] [webgpu] Always use tile matmulnbits for block_size = 32 (#23140) ### Description After the optimization of prefill time with #23102, it seems that always using the tile matmulnibits with block_size = 32 can bring better performance even for discrete gpu for phi3 model. Phi3 becomes 42.64 tokens/sec from 32.82 tokens/sec in easy mode on my NV RTX 2000 GPU. --- .../webgpu/quantization/matmul_nbits.cc | 8 +++----- .../contrib_ops/webgpu/quantization/matmul_nbits.h | 14 ++++++-------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 9a49adf347a29..8abcd78bfff4c 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -60,7 +60,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform); const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); - if ((is_intel_ || tile_m_ > 1) && block_size_ == 32) { + if (block_size_ == 32) { const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY(); const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8; // each uint32 has 8 data. const uint32_t a_length_per_tile = tile_size / a.NumComponents(); @@ -408,14 +408,12 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context const uint32_t components_b = GetMaxComponents(blob_size_in_words); uint32_t components = GetMaxComponents(N); - const bool is_intel = context.AdapterInfo().vendor == std::string_view{"intel"} && - context.AdapterInfo().architecture == std::string_view{"gen-12lp"}; const bool has_zero_points = zero_points != nullptr; // TODO: Support output_number > 1. Some cases are failed when output_number > 1. constexpr uint32_t output_number = 1; const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1; - MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow(components_b), has_zero_points, is_intel}; + MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow(components_b), has_zero_points}; if (M > kMinMForTileOptimization && block_size == 32) { components = 1; constexpr uint32_t workgroup_size = 64; @@ -426,7 +424,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context (M + tile_m - 1) / tile_m, batch_count); program.CacheHint("T_M" + std::to_string(tile_m)); - } else if (is_intel && block_size == 32) { + } else if (block_size == 32) { components = 1; constexpr uint32_t workgroup_size = 128; const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4 diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 8a4626083419c..57615d3ddabcf 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -14,13 +14,12 @@ using namespace onnxruntime::webgpu; class MatMulNBitsProgram final : public Program { public: - MatMulNBitsProgram(uint32_t output_number, uint32_t block_size, uint32_t tile_m, int components_b, bool has_zero_points, bool is_intel) : Program{"MatMulNBits"}, - output_number_{output_number}, - block_size_{block_size}, - tile_m_{tile_m}, - components_b_{components_b}, - has_zero_points_{has_zero_points}, - is_intel_{is_intel} { + MatMulNBitsProgram(uint32_t output_number, uint32_t block_size, uint32_t tile_m, int components_b, bool has_zero_points) : Program{"MatMulNBits"}, + output_number_{output_number}, + block_size_{block_size}, + tile_m_{tile_m}, + components_b_{components_b}, + has_zero_points_{has_zero_points} { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -32,7 +31,6 @@ class MatMulNBitsProgram final : public Program { uint32_t tile_m_; int components_b_; bool has_zero_points_; - bool is_intel_; }; class MatMulNBits final : public WebGpuKernel {