From a72c5874c8e57df782b55d10a67df14942936396 Mon Sep 17 00:00:00 2001 From: sushraja-msft <44513542+sushraja-msft@users.noreply.github.com> Date: Tue, 10 Dec 2024 17:07:11 -0800 Subject: [PATCH] Implement 2d tiled matmulnbits specialized for prefill (#23058) ### Description This change implements matmul4bits with tiling both for A and B. This is beneficial for prefill scenarios on Intel integrated GPUs, because each row of A has to run through the same set of shared rows of B. This change should improve core occupancy and model_benchmark does indicate improvements for prefill. The same shader is not used for generation because when A has just a single row, the other threads in the workgroup get unused and that hurts performance. ``` -- Baseline run on an Alderlake GPU -- C:\onnxruntime>C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500 Batch size: 1, prompt tokens: 501, tokens to generate: 128 Prompt processing (time to first token): avg (us): 1.72338e+07 avg (tokens/s): 29.0707 << p50 (us): 1.72548e+07 stddev (us): 57012.8 n: 5 * 501 token(s) Token generation: avg (us): 79227.5 avg (tokens/s): 12.6219 p50 (us): 79284.4 stddev (us): 2109.72 n: 635 * 1 token(s) Token sampling: avg (us): 15.8198 avg (tokens/s): 63211.8 p50 (us): 14.3 stddev (us): 8.67178 n: 640 * 1 token(s) E2E generation (entire generation loop): avg (ms): 27297.8 p50 (ms): 27269.8 stddev (ms): 89.4322 n: 5 Peak working set size (bytes): 5490987008 WebGPU device lost (2): Device was destroyed. ----------------------------------- With Prefill Optimization ---- C:\onnxruntime>C:\model_benchmark\model_benchmark.exe -i C:\Phi-3.5-mini-instruct-onnx-web\Phi-3.5-mini-instruct-onnx-web -l 500 Batch size: 1, prompt tokens: 501, tokens to generate: 128 Prompt processing (time to first token): avg (us): 1.2135e+07 avg (tokens/s): 41.2856 << p50 (us): 1.21288e+07 stddev (us): 21282.1 n: 5 * 501 token(s) Token generation: avg (us): 78945.3 avg (tokens/s): 12.667 p50 (us): 78900.7 stddev (us): 2232.43 n: 635 * 1 token(s) Token sampling: avg (us): 20.5608 avg (tokens/s): 48636.3 p50 (us): 18.7 stddev (us): 19.0409 n: 640 * 1 token(s) E2E generation (entire generation loop): avg (ms): 22163.8 p50 (ms): 22160.1 stddev (ms): 31.3122 n: 5 Peak working set size (bytes): 5478862848 WebGPU device lost (2): Device was destroyed. ``` --- .../webgpu/quantization/matmul_nbits.cc | 201 +++++++++++++++--- .../webgpu/quantization/matmul_nbits.h | 14 ++ 2 files changed, 186 insertions(+), 29 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 31f95ee64df5d..be18f820e2747 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -39,6 +39,7 @@ std::string QuantizedDataType(int components) { } } +constexpr unsigned int kMinSequenceLengthForPrefillOptimization = 16; } // namespace ONNX_OPERATOR_KERNEL_EX( @@ -321,6 +322,121 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } +Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddInput("scales", ShaderUsage::UseUniform); + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + // This shader uses uniforms with the M,N,K convention from traditional matrix multiplicatiion + // M is the number of rows in A and M rows in the output. + // N is the number of columns in B and N columns in the output. + // K is the hidden/shared dimension number of columns in A and K rows in B. + // Note in matmulnbits, B matrix is already transposed, however the following remains true + // for the shader below M describes A, N describes B and K is the hidden/shared dimension. + // K4/K8 are simply K divided by 4 or 8 respectively. + shader.AdditionalImplementation() << R"INIT_SECTION( +// Matrix dimensions and quantization parameters +const TILE_SIZE : u32 = 16u; +const VALUES_PER_VEC4 : u32 = 4u; +const QUANTIZATION_BLOCK_SIZE : u32 = 32; +// We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM, +// so we use BLOCKS_PER_CYCLE as 2u, or process weights 2 blocks at a time. +// This uses all 16 lanes on 12th gen intel chips. +const BLOCKS_PER_CYCLE : u32 = 2u; +const INNER_DIMENSION_ITEMS_PER_CYCLE : u32 = 16u; // (QUANTIZATION_BLOCK_SIZE/VALUES_PER_VEC4)*BLOCKS_PER_CYCLE +const VECTORIZED_QUANTIZATION_BLOCK_SIZE: u32 = 8u; // QUANTIZATION_BLOCK_SIZE / VALUES_PER_VEC4; + +//Shared memory +var tile_A : array, TILE_SIZE>; +var tile_B : array, TILE_SIZE>; +var tile_O : array, TILE_SIZE>; + +fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32) +{ + if (a_global >= uniforms.M) { + return; + } + let local_A = input_a[a_global*uniforms.K4+step_idx*INNER_DIMENSION_ITEMS_PER_CYCLE+parallel_id]; + tile_A[slot][parallel_id] = local_A; +} + +fn getBScale(slot: u32, b_global : u32, vec_step_idx : u32, scale_idx: u32) -> output_value_t +{ + // Since scales are output_value_t holding 1 for every 32 values, vec_step_idx jumps over 64 weights at + // a time or 2 scales at every step. + let scale_offset = vec_step_idx*2; + let idx = u32(b_global*(uniforms.K/QUANTIZATION_BLOCK_SIZE)+scale_offset); + return scales[idx+scale_idx]; +} + +fn loadB(slot: u32, b_global : u32, vec_step_idx : u32, parallel_id : u32) +{ + if (b_global >= uniforms.N) { + return; + } + let scale = getBScale(slot, b_global, vec_step_idx, u32(parallel_id/VECTORIZED_QUANTIZATION_BLOCK_SIZE)); + let idx:u32 = parallel_id; + if (idx % 2 == 0) + { + // Weights are u32 holding 8 values each, each step (vec_step_idx) jumps over 64 weights at a time. + // Therefore the weight_offset begin for the current step would be vec_step_idx * 64 if weight + // elements were holding one element each. For the case of each element holding 8 values, begin + // would become vec_step_idx * 64/8 or vec_step_idx * 8. + var weight_offset:u32 = (vec_step_idx*8)+ u32(idx/2); + let b_value = input_b[b_global*uniforms.K8+weight_offset]; + let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu); + let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu); + tile_B[slot][idx].x = (output_value_t(b_value_lower[0]) - 8.0) * scale; + tile_B[slot][idx].y = (output_value_t(b_value_upper[0]) - 8.0) * scale; + tile_B[slot][idx].z = (output_value_t(b_value_lower[1]) - 8.0) * scale; + tile_B[slot][idx].w = (output_value_t(b_value_upper[1]) - 8.0) * scale; + tile_B[slot][idx+1].x = (output_value_t(b_value_lower[2]) - 8.0)* scale; + tile_B[slot][idx+1].y = (output_value_t(b_value_upper[2]) - 8.0)* scale; + tile_B[slot][idx+1].z = (output_value_t(b_value_lower[3]) - 8.0)* scale; + tile_B[slot][idx+1].w = (output_value_t(b_value_upper[3]) - 8.0)* scale; + } +} + +fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t +{ + var sum:output_value_t = 0; + for (var idx:u32 = 0 ; idx < INNER_DIMENSION_ITEMS_PER_CYCLE; idx++) + { + sum += dot(tile_A[slot_a][idx], tile_B[slot_b][idx]); + } + return sum; +} +)INIT_SECTION"; + + shader.MainFunctionBody() << R"MAIN_FN( + // Indexing with idx,idy instead of using a 2d dispatch of TILE_SIZE, TILE_SIZE + // appears to give a performance win on Intel Gen12LP architecture. + // This is likley because of locality of memory access, idy below in this approach + // is the same as subgroup_id or lane id, while idx is the wave_id. + // The work distribution therefore keeps memory accesses close together in + // a single wave in this approach of indexing. + let idx = u32(local_idx / TILE_SIZE); + let idy = u32(local_idx % TILE_SIZE); + let a_global_base = workgroup_id.x * TILE_SIZE; + let b_global_base = workgroup_id.y * TILE_SIZE; + let step_count:u32 = u32(uniforms.K/(BLOCKS_PER_CYCLE*QUANTIZATION_BLOCK_SIZE)); + for (var vec_step:u32 = 0; vec_step < step_count; vec_step++) + { + workgroupBarrier(); + loadA(idx, a_global_base+idx, vec_step, idy); + loadB(idx, b_global_base+idx, vec_step, idy); + workgroupBarrier(); + let result = computeDotProduct(idx, idy); + tile_O[idx][idy]+=result; + } + workgroupBarrier(); + if (a_global_base+idx < uniforms.M && b_global_base+idy < uniforms.N) { + output[(a_global_base+idx) * uniforms.N + b_global_base + idy] = tile_O[idx][idy]; + } +)MAIN_FN"; + return Status::OK(); +} + Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* b = context.Input(1); @@ -360,38 +476,65 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context context.AdapterInfo().architecture == std::string_view{"gen-12lp"} && block_size == 32; const bool has_zero_points = zero_points != nullptr; - // TODO: Support output_number > 1. Some cases are failed when output_number > 1. - // const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1; - constexpr uint32_t output_number = 1; - MatMulNBitsProgram program{output_number, gsl::narrow(components_b), has_zero_points, use_block32}; - - if (use_block32) { - components = 1; - constexpr uint32_t workgroup_size = 128; - const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4 - : 1; - const uint32_t workgroup_x = workgroup_size / workgroup_y; - program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); - program.SetDispatchGroupSize(data_size / components / workgroup_y); + + if (use_block32 && batch_count == 1 && + components_a == 4 && components_b == 4 && + !has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) { + MatMulNBitsProgramPrefill program; + constexpr int32_t tile_size = 16; + // subgroup_size here controls how many elements of the hidden dimension we load in a cycle. + // MatMulNBitsProgramPrefill does not use any of the subgroup wgsl instructions. The subgroup + // size just helps with optimal lane usage in the shader. + constexpr int32_t subgroup_size = 16; + program.SetWorkgroupSize(tile_size * subgroup_size); + program.SetDispatchGroupSize((M + tile_size - 1) / tile_size, + (N + tile_size - 1) / tile_size, + 1); + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(4)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(4)}, + {scales, ProgramTensorMetadataDependency::None}}) + .AddUniformVariables({{static_cast(M)}, + {static_cast(N)}, + {static_cast(K)}, + {static_cast(K / 4)}, + {static_cast(K / 8)}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}); + return context.RunProgram(program); } else { - program.SetDispatchGroupSize(data_size / components / output_number); - } + // TODO: Support output_number > 1. Some cases are failed when output_number > 1. + // const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1; + constexpr uint32_t output_number = 1; + MatMulNBitsProgram program{output_number, gsl::narrow(components_b), has_zero_points, use_block32}; + + if (use_block32) { + components = 1; + constexpr uint32_t workgroup_size = 128; + const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4 + : 1; + const uint32_t workgroup_x = workgroup_size / workgroup_y; + program.SetWorkgroupSize(workgroup_x, workgroup_y, 1); + program.SetDispatchGroupSize(data_size / components / workgroup_y); + } else { + program.SetDispatchGroupSize(data_size / components / output_number); + } + + TensorShape reshaped_a_shape{batch_count, M, K / components_a}; + TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; + TensorShape reshaped_y_shape{batch_count, M, N / components}; - TensorShape reshaped_a_shape{batch_count, M, K / components_a}; - TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b}; - TensorShape reshaped_y_shape{batch_count, M, N / components}; - - program - .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, - {scales, ProgramTensorMetadataDependency::None}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}) - .AddUniformVariable({block_size}) - .CacheHint(std::to_string(output_number)); - if (has_zero_points) { - program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + program + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow(components_a)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)}, + {scales, ProgramTensorMetadataDependency::None}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(components)}) + .AddUniformVariable({block_size}) + .CacheHint(std::to_string(output_number)); + if (has_zero_points) { + program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); + } + return context.RunProgram(program); } - return context.RunProgram(program); } } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index c0d6b3e6379cd..5f785c03f6a5e 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -31,6 +31,20 @@ class MatMulNBitsProgram final : public Program { bool use_block32_; }; +class MatMulNBitsProgramPrefill final : public Program { + public: + MatMulNBitsProgramPrefill() : Program{"MatMulNBitsPrefill"} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K4", ProgramUniformVariableDataType::Uint32}, + {"K8", ProgramUniformVariableDataType::Uint32}); +}; + class MatMulNBits final : public WebGpuKernel { public: MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) {