From 379dffa65b278e11d60193745211ff2d3e9a2a78 Mon Sep 17 00:00:00 2001 From: Your Date: Fri, 22 Mar 2024 23:14:31 +0000 Subject: [PATCH] lint --- .../cutlass_extensions/compute_occupancy.h | 61 +- .../epilogue/thread/fused_activations.h | 64 +- .../epilogue_per_row_per_col_scale.h | 493 +++++------ .../threadblock/epilogue_tensor_op_int32.h | 245 +++--- .../moe/cutlass_extensions/epilogue_helpers.h | 42 +- .../gemm/device/gemm_universal_base_compat.h | 483 +++++------ .../gemm/device/splitk_gemm_grouped.h | 773 ++++++++--------- .../gemm/kernel/default_int8_traits.h | 36 +- .../gemm/kernel/default_splitk_gemm_grouped.h | 69 +- .../gemm/kernel/fpA_intB_gemm.h | 793 ++++++++---------- .../gemm/kernel/moe_cutlass_kernel.h | 662 +++++++-------- .../gemm/kernel/moe_problem_visitor.h | 487 +++++------ .../gemm/kernel/splitk_gemm_grouped.h | 671 +++++++-------- .../warp/mma_tensorop_compute_B_with_f16.h | 300 +++---- .../moe/cutlass_extensions/gemm_configs.h | 162 ++-- .../interleaved_numeric_conversion.h | 667 +++++++-------- .../tile_interleaved_layout.h | 27 +- .../fine_grained_scale_zero_iterator.h | 335 ++++---- .../cutlass_extensions/weight_only_quant_op.h | 26 +- .../cuda/moe/ft_moe/cutlass_heuristic.h | 6 +- .../cuda/moe/ft_moe/moe_gemm_kernels.h | 63 +- .../moe/ft_moe/moe_gemm_kernels_template.h | 769 ++++++++--------- .../transformers/test_parity_mixtral_moe.py | 6 +- .../python/transformers/test_parity_moe.py | 4 +- 24 files changed, 3353 insertions(+), 3891 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h index e2ca46d609d9f..1239d04916047 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h @@ -24,40 +24,41 @@ using namespace onnxruntime; namespace ort_fastertransformer { -template inline int compute_occupancy_for_kernel() { - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - if (smem_size > (48 << 10)) { - cudaFuncAttributes attr; - int device = 0; - int max_smem_per_block = 0; - CUDA_CALL_THROW(cudaGetDevice(&device)); - CUDA_CALL_THROW(cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); - if constexpr (enable_cutlass_3x) { - CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::device_kernel)); - } else { - CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::Kernel)); - } - if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) { - // This should mean that - // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) - // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this - // configuration. 
- return 0; - } - } +template +inline int compute_occupancy_for_kernel() { + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - int max_active_blocks = -1; + if (smem_size > (48 << 10)) { + cudaFuncAttributes attr; + int device = 0; + int max_smem_per_block = 0; + CUDA_CALL_THROW(cudaGetDevice(&device)); + CUDA_CALL_THROW(cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); if constexpr (enable_cutlass_3x) { - CUDA_CALL_THROW(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, cutlass::device_kernel, - 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size)); + CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::device_kernel)); } else { - CUDA_CALL_THROW(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::Kernel, - GemmKernel::kThreadCount, smem_size)); + CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::Kernel)); + } + if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) + // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this + // configuration. + return 0; } + } + + int max_active_blocks = -1; + if constexpr (enable_cutlass_3x) { + CUDA_CALL_THROW(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, cutlass::device_kernel, + 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size)); + } else { + CUDA_CALL_THROW(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::Kernel, + GemmKernel::kThreadCount, smem_size)); + } - return max_active_blocks; + return max_active_blocks; } -} // namespace ort_fastertransformer +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h index f3c622b88a5fb..5dd2d3ffa5c54 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h @@ -46,60 +46,50 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace epilogue -{ -namespace thread -{ +namespace cutlass { +namespace epilogue { +namespace thread { ///////////////////////////////////////////////////////////////////////////////////////////////// -__forceinline__ __device__ float copysignf_pos(float a, float b) -{ - float r; - r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); - return r; +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; } -__forceinline__ __device__ float tanh_opt(float x) -{ +__forceinline__ __device__ float tanh_opt(float x) { #if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) - float const exp_val = -1.f * fabs(2 * x); - return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); + float const exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); #else - return fast_tanh(x); + return fast_tanh(x); #endif } ///////////////////////////////////////////////////////////////////////////////////////////////// template <> 
-struct GELU_taylor -{ - static bool const kIsHeavy = true; +struct GELU_taylor { + static bool const kIsHeavy = true; - CUTLASS_DEVICE - float operator()(float const& z) const - { + CUTLASS_DEVICE + float operator()(float const& z) const { + float k0 = float(0.7978845608028654); + float k1 = float(0.044715); - float k0 = float(0.7978845608028654); - float k1 = float(0.044715); + return float(cutlass::constants::half() * z * (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } - return float(cutlass::constants::half() * z - * (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); - } + using Params = LinearCombinationGenericParams; - using Params = LinearCombinationGenericParams; - - CUTLASS_DEVICE - float operator()(float const& scalar, Params const& params_) const - { - return this->operator()(scalar); - } + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const { + return this->operator()(scalar); + } }; -} // namespace thread -} // namespace epilogue -} // namespace cutlass +} // namespace thread +} // namespace epilogue +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h index d3d4d0a45ab29..7d2d2e50004ca 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h @@ -48,305 +48,244 @@ namespace tk = tensorrt_llm::common; -namespace cutlass -{ -namespace epilogue -{ -namespace threadblock -{ +namespace cutlass { +namespace epilogue { +namespace threadblock { template -class EpilogueVisitorPerRowPerCol -{ -public: - using ThreadblockShape = ThreadblockShape_; - static int const kThreadCount = ThreadCount; - - using ScaleTileIterator = ScaleTileIterator_; - using OutputTileIterator = OutputTileIterator_; - using ElementwiseFunctor = ElementwiseFunctor_; - - static int const kIterations = OutputTileIterator::kIterations; - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - - using ElementOutput = typename OutputTileIterator::Element; - using LayoutOutput = cutlass::layout::RowMajor; - using ElementAccumulator = ElementAccumulator_; - - using AlphaScaleElementType = typename ScaleTileIterator::Element; - - using ElementCompute = ElementCompute_; - using AccumulatorFragment = Array; - using ComputeFragment = Array; - using OutputVector = Array; - - static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; - static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); - - /// Argument structure - struct Arguments - { - - typename ElementwiseFunctor::Params elementwise; - int64_t batch_stride_alpha; - int64_t batch_stride_C; - int64_t batch_stride_D; - - // - // Methods - // - Arguments() - : batch_stride_alpha(0) - , batch_stride_C(0) - , batch_stride_D(0) - { - } - - Arguments(typename ElementwiseFunctor::Params elementwise_) - : elementwise(elementwise_) - , batch_stride_alpha(0) - , batch_stride_C(0) - , batch_stride_D(0) - { - } - - Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, - 
int64_t batch_stride_C_, int64_t batch_stride_D_) - : elementwise(elementwise_) - , batch_stride_alpha(batch_stride_alpha_) - , batch_stride_C(batch_stride_C_) - , batch_stride_D(batch_stride_D_) - { - } - }; - - struct Params - { - - typename ElementwiseFunctor::Params elementwise; - int64_t batch_stride_alpha; - int64_t batch_stride_C; - int64_t batch_stride_D; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Arguments const& args) - : elementwise(args.elementwise) - , batch_stride_alpha(args.batch_stride_alpha) - , batch_stride_C(args.batch_stride_C) - , batch_stride_D(args.batch_stride_D) - { - } - }; - - /// Shared storage - struct SharedStorage - { - }; - -private: - Params const& params_; - SharedStorage& shared_storage_; - MatrixCoord extent_; - MatrixCoord extent_real_; - ElementwiseFunctor elementwise_; - - bool const per_token_quant_; - bool const per_channel_quant_; - - AlphaScaleElementType* ptr_alpha_row_; - AlphaScaleElementType* ptr_alpha_col_; - ScaleTileIterator iterator_alpha_col_; - OutputTileIterator iterator_C_; - OutputTileIterator iterator_D_; - - AlphaScaleElementType element_alpha_row_ = 1.0f; - AlphaScaleElementType element_alpha_col_ = 1.0f; - typename ScaleTileIterator::Fragment fragment_alpha_col_; - typename OutputTileIterator::Fragment fragment_C_; - typename OutputTileIterator::Fragment fragment_D_; - - ElementAccumulator beta_; - - int column_offset_; - - MatrixCoord thread_offset_; - -public: - CUTLASS_DEVICE - EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, - cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, - typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C, - typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row, - AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, - typename OutputTileIterator::Element* ptr_D, - cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0, - cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) - : params_(params) - , shared_storage_(shared_storage) - , extent_(problem_size) - , elementwise_(params.elementwise) - , per_token_quant_(quant_option.hasPerTokenScaling()) - , per_channel_quant_(quant_option.hasPerChannelScaling()) - , ptr_alpha_row_(ptr_alpha_row) - , ptr_alpha_col_(ptr_alpha_col) - , iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset) - , iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset) - , iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset) - , extent_real_(problem_size_real) - { - beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); - - if (beta_ == ElementAccumulator()) - { - iterator_C_.clear_mask(); - } - - if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) - { - element_alpha_col_ = *ptr_alpha_col_; - } - - if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) - { - element_alpha_row_ = *ptr_alpha_row_; - } + typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_, bool UseMasking_ = false> +class EpilogueVisitorPerRowPerCol { + public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using AlphaScaleElementType = typename ScaleTileIterator::Element; + + using ElementCompute = ElementCompute_; + using AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + + /// Argument structure + struct Arguments { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + Arguments() + : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) { } - /// Helper to indicate split-K behavior - CUTLASS_DEVICE - void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme - int split_k_slices) - { ///< Total number of split-K slices + Arguments(typename ElementwiseFunctor::Params elementwise_) + : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) { } - /// Called to set the batch index - CUTLASS_DEVICE - void set_batch_index(int batch_idx) - { - iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); - iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); - iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, + int64_t batch_stride_C_, int64_t batch_stride_D_) + : elementwise(elementwise_), batch_stride_alpha(batch_stride_alpha_), batch_stride_C(batch_stride_C_), batch_stride_D(batch_stride_D_) { } - - /// Called at the start of the epilogue just before iterating over accumulator slices - CUTLASS_DEVICE - void begin_epilogue() - { - if (per_channel_quant_) - { - iterator_alpha_col_.load(fragment_alpha_col_); - } + }; + + struct Params { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args) + : elementwise(args.elementwise), batch_stride_alpha(args.batch_stride_alpha), batch_stride_C(args.batch_stride_C), batch_stride_D(args.batch_stride_D) { } - - /// Called at the start of one step before starting accumulator exchange - CUTLASS_DEVICE - void begin_step(int step_idx) - { - fragment_D_.clear(); - 
fragment_C_.clear(); - - if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) - { - iterator_C_.load(fragment_C_); - ++iterator_C_; - } + }; + + /// Shared storage + struct SharedStorage { + }; + + private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + bool const per_token_quant_; + bool const per_channel_quant_; + + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + + AlphaScaleElementType element_alpha_row_ = 1.0f; + AlphaScaleElementType element_alpha_col_ = 1.0f; + typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator beta_; + + int column_offset_; + + MatrixCoord thread_offset_; + + public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) + : params_(params), shared_storage_(shared_storage), extent_(problem_size), elementwise_(params.elementwise), per_token_quant_(quant_option.hasPerTokenScaling()), per_channel_quant_(quant_option.hasPerChannelScaling()), ptr_alpha_row_(ptr_alpha_row), ptr_alpha_col_(ptr_alpha_col), iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset), iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset), iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset), extent_real_(problem_size_real) { + beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); } - /// Called at the start of a row - CUTLASS_DEVICE - void begin_row(int row_idx) - { - // load alpha_row in begin_step only when per token(row) scaling is used - if (per_token_quant_) - { - int thread_offset_row - = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); - - arch::global_load( - element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); - } + if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) { + element_alpha_col_ = *ptr_alpha_col_; } - /// Called after accumulators have been exchanged for each accumulator vector - CUTLASS_DEVICE - void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) - { - - NumericArrayConverter source_converter; - - ComputeFragment result = source_converter(accum); - if (per_channel_quant_) - { - ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; - result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); - } - else - { - result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); - } - - // Convert to the output - NumericArrayConverter output_converter; - OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; - output = output_converter(result); + if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) { + element_alpha_row_ = *ptr_alpha_row_; } - - /// Called at the end of a row - CUTLASS_DEVICE - void end_row(int row_idx) {} - - /// Called after all accumulator elements have been visited - CUTLASS_DEVICE - void end_step(int step_idx) - { - - iterator_D_.store(fragment_D_); - ++iterator_D_; + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() { + if (per_channel_quant_) { + iterator_alpha_col_.load(fragment_alpha_col_); } + } - /// Called after all steps have been completed - CUTLASS_DEVICE - void end_epilogue() {} + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + fragment_C_.clear(); -private: - CUTLASS_DEVICE - ComputeFragment per_token_channel_scale_accumulator_( - ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) - { + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } - ComputeFragment result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ComputeFragment::kElements; ++i) - { - result[i] = accum[i] * (scale_col[i] * scale_row); - } + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) { + int thread_offset_row = 
iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); - return result; + arch::global_load( + element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); + } + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) { + NumericArrayConverter source_converter; + + ComputeFragment result = source_converter(accum); + if (per_channel_quant_) { + ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; + result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); + } else { + result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); } - CUTLASS_DEVICE - ComputeFragment per_token_scale_accumulator_( - ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) - { + // Convert to the output + NumericArrayConverter output_converter; + OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) {} + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + + private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_( + ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col[i] * scale_row); + } - ComputeFragment result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ComputeFragment::kElements; ++i) - { - result[i] = accum[i] * (scale_col * scale_row); - } + return result; + } - return result; + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_( + ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col * scale_row); } + + return result; + } }; -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h index 6f26d79017034..88d09634888a1 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -80,35 +80,28 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace epilogue -{ -namespace threadblock -{ +namespace cutlass { +namespace epilogue { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// -namespace detail -{ +namespace detail { /// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. 
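For reference, the dequantization applied by EpilogueVisitorPerRowPerCol above reduces to multiplying each accumulator by its row (per-token) and column (per-channel) scale. A minimal standalone C++ sketch of that math, using illustrative names that are not taken from this patch (the visitor performs the same computation per register fragment rather than over whole arrays):

#include <vector>

// out[m][n] = acc[m][n] * row_scale[m] * col_scale[n]
// row_scale holds per-token (per-row) factors, col_scale per-channel (per-column) factors.
void dequant_reference(const std::vector<float>& acc, const std::vector<float>& row_scale,
                       const std::vector<float>& col_scale, std::vector<float>& out,
                       int M, int N) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      out[m * N + n] = acc[m * N + n] * row_scale[m] * col_scale[n];
    }
  }
}

When only one of the two scale modes is enabled, the corresponding factor collapses to the single broadcast value loaded in the visitor's constructor (element_alpha_row_ / element_alpha_col_).
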
template struct DefaultIteratorsTensorOp -{ - using WarpTileIterator - = cutlass::epilogue::warp::TileIteratorTensorOpMixed; + ThreadMap> { + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed; - using SharedLoadIterator - = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; - static int const kFragmentsPerIteration = 2; + static int const kFragmentsPerIteration = 2; }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace detail +} // namespace detail ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -116,167 +109,139 @@ struct DefaultIteratorsTensorOp -class SharedLoadIteratorMixed -{ -public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; +template +class SharedLoadIteratorMixed { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; - using Element = int32_t; + using Element = int32_t; - using Layout = layout::RowMajor; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; - static int const kThreads = ThreadMap::kThreads; + static int const kThreads = ThreadMap::kThreads; - /// Fragment object - using Fragment = Array; + /// Fragment object + using Fragment = Array; - /// Memory access size - using AccessType = AlignedArray; + /// Memory access size + using AccessType = AlignedArray; - /// Vector type used for SMEM loads - using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), - const_min(16, kAlignment)>; + /// Vector type used for SMEM loads + using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; - static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; -private: - // - // Data members - // + private: + // + // Data members + // - /// Byte-level pointer - LoadType const* pointers_[kLoadsPerAccess]; + /// Byte-level pointer + LoadType const* pointers_[kLoadsPerAccess]; - /// Stride along adjacent rows in units of LoadType - int stride_; + /// Stride along adjacent rows in units of LoadType + int stride_; -public: - // - // Methods - // + public: + // + // Methods + // - /// Constructor - CUTLASS_DEVICE - SharedLoadIteratorMixed(TensorRef ref, int thread_idx) - : stride_((ref.stride(0) / LoadType::kElements)) - { + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed(TensorRef ref, int thread_idx) + : stride_((ref.stride(0) / LoadType::kElements)) { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + // Initialize pointers + 
CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] = reinterpret_cast(ref.data()); - // Initialize pointers - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] = reinterpret_cast(ref.data()); - - int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; - int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; + int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; - col_idx += (bank_offset + i) % kLoadsPerAccess; + col_idx += (bank_offset + i) % kLoadsPerAccess; - pointers_[i] += thread_offset.row() * stride_ + col_idx; - } + pointers_[i] += thread_offset.row() * stride_ + col_idx; } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] += pointer_offset / LoadType::kElements; - } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += pointer_offset / LoadType::kElements; } + } - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const& offset) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] - += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; - } + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const - { - + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const { + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) - { + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_ + pointer_offset / LoadType::kElements; - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) - { - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) - { - - int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ - + group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_ - + pointer_offset / LoadType::kElements; - - int frag_row_idx - = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - LoadType* frag_ptr = reinterpret_cast(&frag); + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) - { + LoadType* frag_ptr = reinterpret_cast(&frag); - int frag_idx = frag_row_idx * 
ThreadMap::Iterations::kColumn + column; + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kLoadsPerAccess; ++v) - { - - int vector_idx - = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + int vector_idx = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); - LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; + LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; - frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; - } - } - } + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; } + } } + } } + } - /// Loads a fragment - CUTLASS_DEVICE - void load(Fragment& frag) const - { - - load_with_pointer_offset(frag, 0); - } + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment& frag) const { + load_with_pointer_offset(frag, 0); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass //////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h index 24db386d31148..815b56ca842ea 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h @@ -133,56 +133,56 @@ constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScali template struct Epilogue { - using Op = cutlass::epilogue::thread::LinearCombinationSilu; + using Op = cutlass::epilogue::thread::LinearCombinationSilu; }; template struct Epilogue { - using Op = cutlass::epilogue::thread::LinearCombinationRelu; + using Op = cutlass::epilogue::thread::LinearCombinationRelu; }; template struct Epilogue { - using Op = cutlass::epilogue::thread::LinearCombinationGeneric< - cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator, - ElementAccumulator, BiasScaleMode, cutlass::FloatRoundStyle::round_to_nearest, true>; + using Op = cutlass::epilogue::thread::LinearCombinationGeneric< + cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator, + ElementAccumulator, BiasScaleMode, cutlass::FloatRoundStyle::round_to_nearest, true>; }; template struct Epilogue { - using Op = cutlass::epilogue::thread::LinearCombination; + using Op = cutlass::epilogue::thread::LinearCombination; }; constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; template struct Epilogue { - using Op = - cutlass::epilogue::thread::LinearCombinationSilu; + using Op = + cutlass::epilogue::thread::LinearCombinationSilu; }; template struct Epilogue { - using Op = - cutlass::epilogue::thread::LinearCombinationRelu; + using Op = + cutlass::epilogue::thread::LinearCombinationRelu; }; template struct Epilogue { - using Op = cutlass::epilogue::thread::LinearCombinationGeneric< - cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator, - ElementAccumulator, DefaultScaleMode, 
cutlass::FloatRoundStyle::round_to_nearest, true>; + using Op = cutlass::epilogue::thread::LinearCombinationGeneric< + cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator, + ElementAccumulator, DefaultScaleMode, cutlass::FloatRoundStyle::round_to_nearest, true>; }; template struct Epilogue { - using Op = cutlass::epilogue::thread::LinearCombination; + using Op = cutlass::epilogue::thread::LinearCombination; }; -} // namespace ort_fastertransformer +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h index 2edd5a228b470..812ada147144c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h @@ -54,12 +54,9 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace device -{ +namespace cutlass { +namespace gemm { +namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -73,366 +70,314 @@ namespace device */ template -class GemmUniversalBaseCompat -{ -public: - using GemmKernel = GemmKernel_; - using ThreadblockShape = typename GemmKernel::Mma::Shape; +class GemmUniversalBaseCompat { + public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; - using ElementA = typename GemmKernel::ElementA; - using LayoutA = typename GemmKernel::LayoutA; - using TensorRefA = TensorRef; - static ComplexTransform const kTransformA = GemmKernel::kTransformA; + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; - using ElementB = typename GemmKernel::ElementB; - using LayoutB = typename GemmKernel::LayoutB; - using TensorRefB = TensorRef; - static ComplexTransform const kTransformB = GemmKernel::kTransformB; + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; - using ElementC = typename GemmKernel::ElementC; - using LayoutC = typename GemmKernel::LayoutC; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; - using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; - using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; - using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; - using Operator = typename GemmKernel::Operator; + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; - /// Argument structure - using Arguments = typename GemmKernel::Arguments; + /// Argument structure + using Arguments = typename GemmKernel::Arguments; -protected: - /// Kernel parameters object - typename GemmKernel::Params params_; + protected: + /// Kernel 
parameters object + typename GemmKernel::Params params_; -protected: - /// Private helper to obtain the grid dimensions with fix-up for split-K - static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) - { + protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); - grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + gemm_k_size = args.problem_size.k(); - gemm_k_size = args.problem_size.k(); + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); - if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) - { + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); - int const kAlignK - = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); - - gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); - - if (gemm_k_size) - { - grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); - } - } + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } } + } -public: - /// Constructs the GEMM. - GemmUniversalBaseCompat() {} - - /// Determines whether the GEMM can execute the given problem. - static Status can_implement(Arguments const& args) - { - - // Determine grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; + public: + /// Constructs the GEMM. + GemmUniversalBaseCompat() {} - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + /// Determines whether the GEMM can execute the given problem. 
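The split-K sizing in get_grid_shape_ above is easiest to follow with concrete numbers. A small self-contained C++ sketch of the same arithmetic, assuming illustrative element widths (16-bit A, 8-bit B) that are not taken from this patch:

#include <algorithm>
#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }
static int round_up(int a, int b) { return ceil_div(a, b) * b; }

int main() {
  int k = 4096;                 // problem_size.k()
  int batch_count = 3;          // requested number of split-K slices
  int bits_a = 16, bits_b = 8;  // assumed element widths, for illustration only

  int align_k = std::max(std::max(128 / bits_a, 128 / bits_b), 1);  // 16
  int gemm_k_size = round_up(ceil_div(k, batch_count), align_k);    // ceil(4096/3)=1366 -> 1376
  int grid_k = ceil_div(k, gemm_k_size);                            // 3 k-partitions
  std::printf("align_k=%d gemm_k_size=%d grid_k=%d\n", align_k, gemm_k_size, grid_k);
  return 0;
}

Rounding gemm_k_size up to kAlignK keeps each K partition a multiple of the 128-bit access granularity for both operands, which is why the realized grid.k() can differ from the requested batch_count.
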
+ static Status can_implement(Arguments const& args) { + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; - ThreadblockSwizzle threadblock_swizzle; - dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) - { + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); - return Status::kErrorInvalidProblem; - } - - return GemmKernel::can_implement(args); + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) { + return Status::kErrorInvalidProblem; } - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); - - size_t workspace_bytes = 0; + return GemmKernel::can_implement(args); + } - // Determine grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + size_t workspace_bytes = 0; - if (args.mode == GemmUniversalMode::kGemmSplitKParallel) - { + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; - // Split-K parallel always requires a temporary workspace - workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); - } - else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) - { - - // Serial split-K only requires a temporary workspace if the number of partitions along the - // GEMM K dimension is greater than one. - workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); - } - - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - - workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - return workspace_bytes; + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); + } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) { + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. 
+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); } - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args) - { + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); + workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); - ThreadblockSwizzle threadblock_swizzle; + return workspace_bytes; + } - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + ThreadblockSwizzle threadblock_swizzle; - CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" - << " result = {" << result << "}"); + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; - return result; - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) - { + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); + CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); - int max_active_blocks = -1; - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + return result; + } - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); - if (smem_size <= (48 << 10)) - { + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); - - if (result == cudaSuccess) - { - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; - } - } - else - { + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - // Query assuming zero shared memory then compute occupancy limit based on SMEM - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); + if (smem_size <= (48 << 10)) { + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); - if (result != cudaSuccess) - { + if (result == cudaSuccess) { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } else { + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); - return -1; - } - - if (smem_capacity < 0) - { - int device_idx = 0; - result = cudaGetDevice(&device_idx); + return -1; + } - if (result != 
cudaSuccess) - { - return -1; - } + if (smem_capacity < 0) { + int device_idx = 0; + result = cudaGetDevice(&device_idx); - cudaDeviceProp properties; - result = cudaGetDeviceProperties(&properties, device_idx); + if (result != cudaSuccess) { + return -1; + } - if (result != cudaSuccess) - { - return -1; - } + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); - smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); - } + if (result != cudaSuccess) { + return -1; + } - int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } - CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); - return occupancy; - } + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); - CUTLASS_TRACE_HOST(" returning internal error"); - - return -1; + return occupancy; } - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { + CUTLASS_TRACE_HOST(" returning internal error"); - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); + return -1; + } - size_t workspace_bytes = get_workspace_size(args); + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + size_t workspace_bytes = get_workspace_size(args); - if (workspace_bytes) - { + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - if (!workspace) - { - CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + if (workspace_bytes) { + if (!workspace) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); - return Status::kErrorWorkspaceNull; - } + return Status::kErrorWorkspaceNull; + } - if (args.mode == GemmUniversalMode::kGemm) - { - CUTLASS_TRACE_HOST(" clearing device workspace"); - cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + if (args.mode == GemmUniversalMode::kGemm) { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } + return Status::kErrorInternal; } + } + } - // Get CUDA grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - // Initialize the Params structure - params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); + // Initialize the Params structure + params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); - // Specify shared memory capacity for kernel. 
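The smem_size handling in initialize() below, like compute_occupancy_for_kernel at the top of this patch, follows the standard CUDA opt-in path for kernels that need more than 48 KB of dynamic shared memory. A generic sketch of that pattern, with a placeholder kernel pointer rather than a symbol from this file (error checks on the two device queries elided for brevity):

#include <cuda_runtime.h>

// Opt a kernel in to > 48 KB of dynamic shared memory, bounded by the device limit.
// kernel_fn is whatever __global__ entry point the caller launches (placeholder here).
cudaError_t enable_large_smem(const void* kernel_fn, int smem_size) {
  if (smem_size <= (48 << 10)) {
    return cudaSuccess;  // the default per-block limit already covers this kernel
  }
  int device = 0;
  int max_optin = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&max_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  if (smem_size > max_optin) {
    return cudaErrorInvalidValue;  // this configuration cannot run on the device
  }
  return cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}

This is also why compute_occupancy_for_kernel reports an occupancy of 0 when smem_size plus the kernel's static shared memory exceeds the opt-in limit: the cudaFuncSetAttribute call would fail, so the heuristic drops that tile configuration.
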
- int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - if (smem_size >= (48 << 10)) - { - cudaError_t result - = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - if (result != cudaSuccess) - { - return Status::kErrorInternal; - } - } - - return Status::kSuccess; + if (result != cudaSuccess) { + return Status::kErrorInternal; + } } - /// Lightweight update given a subset of arguments - Status update(Arguments const& args, void* workspace = nullptr) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + return Status::kSuccess; + } - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); - params_.update(args, workspace); + size_t workspace_bytes = get_workspace_size(args); - return Status::kSuccess; + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; } - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) - { - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + params_.update(args, workspace); - // - // Configure grid and block dimensions - // + return Status::kSuccess; + } - ThreadblockSwizzle threadblock_swizzle; + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); - dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); - dim3 block(GemmKernel::kThreadCount, 1, 1); + // + // Configure grid and block dimensions + // - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + ThreadblockSwizzle threadblock_swizzle; - // - // Launch kernel - // + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); - CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - // Launch - cutlass::Kernel<<>>(params_); + // + // Launch kernel + // - // - // Query for errors - // - cudaError_t result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } + // Launch + cutlass::Kernel<<>>(params_); - return Status::kSuccess; - } + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) - { - return run(stream); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; } - /// Runs the kernel using initialized state. 
- Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { + return Status::kSuccess; + } - Status status = initialize(args, workspace, stream); + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } - if (status == Status::kSuccess) - { - status = run(stream); - } + /// Runs the kernel using initialized state. + Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); - return status; + if (status == Status::kSuccess) { + status = run(stream); } + + return status; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace gemm -} // namespace cutlass +} // namespace device +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h index bfd3666b9c189..7aa70121e2483 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h @@ -55,488 +55,425 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace device -{ +namespace cutlass { +namespace gemm { +namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// template __global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk, - int64_t* splitk_buffer_offsets) -{ - // in_tensor: [problem_idx, k_partition, hidden_size] - // Note that different requests of in_tensor might have different hidden_size (=m*n) - // so, we need to use splitk_buffer_offsets. - // out_tensor: problem_idx * [hidden_size] - - int const problem_idx = blockIdx.y; - GemmCoord problem = problem_sizes[problem_idx]; - int const hidden_size = problem.m() * problem.n(); - const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; - T_OUT* out_tensor_ = out_tensor[problem_idx]; - - for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x) - { - float sum = 0.0f; - for (int k_idx = 0; k_idx < splitk; k_idx++) - { - sum += (float) in_tensor_[k_idx * hidden_size + i]; - } - out_tensor_[i] = (T_OUT) (sum); + int64_t* splitk_buffer_offsets) { + // in_tensor: [problem_idx, k_partition, hidden_size] + // Note that different requests of in_tensor might have different hidden_size (=m*n) + // so, we need to use splitk_buffer_offsets. 
+ // out_tensor: problem_idx * [hidden_size] + + int const problem_idx = blockIdx.y; + GemmCoord problem = problem_sizes[problem_idx]; + int const hidden_size = problem.m() * problem.n(); + const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; + T_OUT* out_tensor_ = out_tensor[problem_idx]; + + for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x) { + float sum = 0.0f; + for (int k_idx = 0; k_idx < splitk; k_idx++) { + sum += (float)in_tensor_[k_idx * hidden_size + i]; } + out_tensor_[i] = (T_OUT)(sum); + } } /// GEMM Grouped template -class BaseSplitkGrouped -{ -public: - using BaseKernel = BaseKernel_; - - using ElementA = typename BaseKernel::ElementA; - using LayoutA = typename BaseKernel::LayoutA; - using TensorRefA = TensorRef; - static ComplexTransform const kTransformA = BaseKernel::kTransformA; - static int const kAlignmentA = BaseKernel::kAlignmentA; - - using ElementB = typename BaseKernel::ElementB; - using LayoutB = typename BaseKernel::LayoutB; - using TensorRefB = TensorRef; - static ComplexTransform const kTransformB = BaseKernel::kTransformB; - static int const kAlignmentB = BaseKernel::kAlignmentB; - - using ElementC = typename BaseKernel::ElementC; - using LayoutC = typename BaseKernel::LayoutC; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - static int const kAlignmentC = BaseKernel::kAlignmentC; - - using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; - - using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; - using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle; - - using Operator = typename BaseKernel::Operator; - using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; - - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename WarpMmaOperator::MathOperator; - using OperatorClass = typename WarpMmaOperator::OperatorClass; - using ArchTag = typename WarpMmaOperator::ArchTag; - using ThreadblockShape = typename BaseKernel::Mma::Shape; - using WarpShape = typename BaseKernel::WarpShape; - using InstructionShape = typename BaseKernel::InstructionShape; - static int const kStages = BaseKernel::Mma::kStages; - - /// Argument structure - using Arguments = typename BaseKernel::Arguments; - - using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; - -protected: - /// Kernel parameters object - typename BaseKernel::Params gemm_params_; - -private: - /// Get the number of tiles across all problems in a group - static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) - { - int32_t tiles = 0; - for (int32_t i = 0; i < problem_count; ++i) - { - cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; - BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); - tiles += problem_tile_count(problem); - } - return tiles; +class BaseSplitkGrouped { + public: + using BaseKernel = BaseKernel_; + + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + 
+ using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + /// Argument structure + using Arguments = typename BaseKernel::Arguments; + + using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; + + protected: + /// Kernel parameters object + typename BaseKernel::Params gemm_params_; + + private: + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) { + int32_t tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; + BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); + tiles += problem_tile_count(problem); + } + return tiles; + } + + /// Copy from `data` to `workspace` + Status copy_to_workspace(void* workspace, void* data, size_t bytes) { + cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); + if (cuda_error != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + cuda_error = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error)); + return Status::kErrorInternal; } - /// Copy from `data` to `workspace` - Status copy_to_workspace(void* workspace, void* data, size_t bytes) - { - cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); - if (cuda_error != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - cuda_error = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error)); - return Status::kErrorInternal; - } - - return Status::kSuccess; + return Status::kSuccess; + } + + /// Precomputes scheduling information for the grouped GEMM + Status precompute(Arguments const& args, int32_t tile_count, void* workspace) { + size_t workspace_bytes = get_workspace_size(args); + std::vector host_workspace(workspace_bytes); + BaseKernel::ProblemVisitor::host_precompute( + args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*)host_workspace.data()); + return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + } + + /// Reorder `data` according to `indices` + template + static void reorder_array(T* data, std::vector const& indices) { + // For now, simply create a copy of the data and then copy over to the original. 
+ std::vector copy(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + copy.at(i) = data[indices[i]]; } - /// Precomputes scheduling information for the grouped GEMM - Status precompute(Arguments const& args, int32_t tile_count, void* workspace) - { - size_t workspace_bytes = get_workspace_size(args); - std::vector host_workspace(workspace_bytes); - BaseKernel::ProblemVisitor::host_precompute( - args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data()); - return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + memcpy(data, copy.data(), indices.size() * sizeof(T)); + } + + public: + /// Constructs the GEMM. + BaseSplitkGrouped() {} + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + return BaseKernel::can_implement(args); + } + + /// Get the number of tiles in a problem + static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) { + auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); + return BaseKernel::ProblemVisitor::tile_count(grid); + } + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(Arguments const& args) { + if (args.host_problem_sizes == nullptr) { + CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); + return -1; } - /// Reorder `data` according to `indices` - template - static void reorder_array(T* data, std::vector const& indices) - { - // For now, simply create a copy of the data and then copy over to the original. - std::vector copy(indices.size()); - for (size_t i = 0; i < indices.size(); ++i) - { - copy.at(i) = data[indices[i]]; - } - - memcpy(data, copy.data(), indices.size() * sizeof(T)); + return group_tile_count(args.host_problem_sizes, args.problem_count); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + size_t total_mn = 0; + for (int i = 0; i < args.problem_count; i++) { + total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); } + size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices; -public: - /// Constructs the GEMM. - BaseSplitkGrouped() {} + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size( + args.host_problem_sizes, args.problem_count, args.threadblock_count); + } + return workSpaceSize; + } - /// Determines whether the GEMM can execute the given problem. 
- static Status can_implement(Arguments const& args) - { + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + return dim3(args.threadblock_count, 1, 1); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - return BaseKernel::can_implement(args); + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); + return -1; + } } - /// Get the number of tiles in a problem - static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) - { - auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); - return BaseKernel::ProblemVisitor::tile_count(grid); + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, BaseKernel::kThreadCount, smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + return -1; } - /// Get the number of tiles across all problems in a group - static int32_t group_tile_count(Arguments const& args) - { - if (args.host_problem_sizes == nullptr) - { - CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); - return -1; - } + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Sorts each pointer passed in according to the indices that sort + /// `problem_sizes_ptr` in descending order of problem-K dimension. + static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr, + int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr, + int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr) { + std::vector indices(problem_count); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), indices.end(), + [&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); }); + + reorder_array(problem_sizes_ptr, indices); + reorder_array(lda_host_ptr, indices); + reorder_array(ldb_host_ptr, indices); + reorder_array(ldc_host_ptr, indices); + reorder_array(ldd_host_ptr, indices); + reorder_array(offset_A_ptr, indices); + reorder_array(offset_B_ptr, indices); + reorder_array(offset_C_ptr, indices); + reorder_array(offset_D_ptr, indices); + } + + /// Computes the number of threadblocks to launch for the grouped kernel + static int sufficient( + cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1) { + // Determine the number of blocks that would be launched to fill up a single + // wave on the GPU with each SM having maximum occupancy. 
+ int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result)); + return 0; + } - return group_tile_count(args.host_problem_sizes, args.problem_count); + int multiprocessor_count; + result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result)); + return 0; } - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args) - { - size_t total_mn = 0; - for (int i = 0; i < args.problem_count; i++) - { - total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); - } - size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices; - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size( - args.host_problem_sizes, args.problem_count, args.threadblock_count); - } - return workSpaceSize; + bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count); + if (override_sm_count) { + available_sm_count = multiprocessor_count; } - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args) - { + int max_active_blocks = maximum_active_blocks(); + if (max_active_blocks <= 0) { + return 0; + } - return dim3(args.threadblock_count, 1, 1); + int occupancy_based_block_count = available_sm_count * max_active_blocks; + + if (problem_sizes_ptr == nullptr || problem_count == 0) { + return occupancy_based_block_count; } - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) - { + int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); - CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); - - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - - cudaError_t result; - if (smem_size > (48 << 10)) - { - result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); - return -1; - } - } - - int max_active_blocks = -1; - result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, BaseKernel::kThreadCount, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); - return -1; - } - - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; + // If the group contains a single problem, launching the exact number of + // threadblocks needed to cover the problem minimizes the work performed + // per threadblock in finding the next tile to compute. We return total_tiles + // unless the user has provided the SM count. 
+ if (problem_count == 1 && override_sm_count) { + return total_tiles; } - /// Sorts each pointer passed in according to the indices that sort - /// `problem_sizes_ptr` in descending order of problem-K dimension. - static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr, - int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr, - int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr) - { - std::vector indices(problem_count); - std::iota(indices.begin(), indices.end(), 0); - std::stable_sort(indices.begin(), indices.end(), - [&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); }); - - reorder_array(problem_sizes_ptr, indices); - reorder_array(lda_host_ptr, indices); - reorder_array(ldb_host_ptr, indices); - reorder_array(ldc_host_ptr, indices); - reorder_array(ldd_host_ptr, indices); - reorder_array(offset_A_ptr, indices); - reorder_array(offset_B_ptr, indices); - reorder_array(offset_C_ptr, indices); - reorder_array(offset_D_ptr, indices); + // Choose between the full wave of threadblocks and the tile count. If there + // are fewer tiles in the group than threadblocks in the full wave, only + // some threadblocks will be assigned tiles. Those threadblocks + // which are not assigned tiles still need to perform the work of iterating through + // problem sizes to determine that they have no work to do. This competes for cycles + // with those threadblocks that are assigned tiles to compute. + return std::min(total_tiles, occupancy_based_block_count); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Workspace + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; } - /// Computes the number of threadblocks to launch for the grouped kernel - static int sufficient( - cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1) - { - // Determine the number of blocks that would be launched to fill up a single - // wave on the GPU with each SM having maximum occupancy. 
- int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result)); - return 0; - } - - int multiprocessor_count; - result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result)); - return 0; - } - - bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count); - if (override_sm_count) - { - available_sm_count = multiprocessor_count; - } - - int max_active_blocks = maximum_active_blocks(); - if (max_active_blocks <= 0) - { - return 0; - } - - int occupancy_based_block_count = available_sm_count * max_active_blocks; - - if (problem_sizes_ptr == nullptr || problem_count == 0) - { - return occupancy_based_block_count; - } - - int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); - - // If the group contains a single problem, launching the exact number of - // threadblocks needed to cover the problem minimizes the work performed - // per threadblock in finding the next tile to compute. We return total_tiles - // unless the user has provided the SM count. - if (problem_count == 1 && override_sm_count) - { - return total_tiles; - } - - // Choose between the full wave of threadblocks and the tile count. If there - // are fewer tiles in the group than threadblocks in the full wave, only - // some threadblocks will be assigned tiles. Those threadblocks - // which are not assigned tiles still need to perform the work of iterating through - // problem sizes to determine that they have no work to do. This competes for cycles - // with those threadblocks that are assigned tiles to compute. - return std::min(total_tiles, occupancy_based_block_count); + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); + } else { + gemm_params_ = typename BaseKernel::Params(args, workspace); } - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); - - // Workspace - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - int32_t tile_count = group_tile_count(args); - Status status = precompute(args, tile_count, workspace); - if (status != Status::kSuccess) - { - return status; - } - - gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); - } - else - { - gemm_params_ = typename BaseKernel::Params(args, workspace); - } - - // Specify shared memory capacity for kernel. 
- int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) - { - cudaError_t result - = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - return Status::kErrorInternal; - } - } - - return Status::kSuccess; + if (result != cudaSuccess) { + return Status::kErrorInternal; + } } - /// Lightweight update given a subset of arguments - Status update(Arguments const& args, void* workspace = nullptr) - { + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + size_t workspace_bytes = get_workspace_size(args); - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - int32_t tile_count = group_tile_count(args); - Status status = precompute(args, tile_count, workspace); - if (status != Status::kSuccess) - { - return status; - } - - gemm_params_.update(args, workspace, tile_count); - } - else - { - gemm_params_.update(args, workspace); - } - - return Status::kSuccess; + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; } - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) - { - if (!gemm_params_.problem_visitor.problem_count) - { - return Status::kSuccess; - } - - // - // Launch kernel - // - - // Launch splitk grouped gemm - { - dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); - dim3 block(BaseKernel::kThreadCount, 1, 1); - - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - cutlass::Kernel<<>>(gemm_params_); - - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - - // Launch splitkReduction - { - dim3 grid(32, gemm_params_.problem_visitor.problem_count); - dim3 block(256); - splitkReduction<<>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split, - gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices, - gemm_params_.splitk_buffer_offsets); - - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - - return Status::kSuccess; + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + gemm_params_.update(args, workspace, tile_count); + } else { + gemm_params_.update(args, workspace); } - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + if (!gemm_params_.problem_visitor.problem_count) { + return Status::kSuccess; + } + + // + // Launch kernel + // + + // Launch splitk grouped gemm { - return run(stream); + dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); + dim3 block(BaseKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + cutlass::Kernel<<>>(gemm_params_); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } } - /// Initializes and runs the kernel. - Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr) + // Launch splitkReduction { + dim3 grid(32, gemm_params_.problem_visitor.problem_count); + dim3 block(256); + splitkReduction<<>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split, + gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices, + gemm_params_.splitk_buffer_offsets); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } - Status status = initialize(args, workspace, stream); + return Status::kSuccess; + } - if (status == Status::kSuccess) - { - status = run(stream); - } + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } - return status; + /// Initializes and runs the kernel. + Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); } + + return status; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM Grouped template -class SplitkGemmGrouped : public BaseSplitkGrouped -{ -public: - using GemmKernel = GemmKernel_; +class SplitkGemmGrouped : public BaseSplitkGrouped { + public: + using GemmKernel = GemmKernel_; }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace gemm -} // namespace cutlass +} // namespace device +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_int8_traits.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_int8_traits.h index 3fd722994e296..fe4bc0940d9e8 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_int8_traits.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_int8_traits.h @@ -22,36 +22,30 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/layout/matrix.h" -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { template -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassSimt; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassSimt; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; }; // ======================= Turing Traits ============================== template <> -struct Int8GemmArchTraits -{ - using 
OperatorClass = cutlass::arch::OpClassTensorOp; - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; }; // ======================= Ampere Traits ============================== template <> -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassTensorOp; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; }; -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h index 1dbd0b1765fbb..57c4bcadef0cc 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h @@ -59,12 +59,9 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -111,7 +108,7 @@ template < GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, /// Operation performed by GEMM typename Operator = typename device::DefaultGemmConfiguration::Operator, + ElementC_, ElementAccumulator>::Operator, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, /// Permute result D @@ -169,39 +166,37 @@ template < /// Permute result D typename PermuteDLayout> struct DefaultSplitkGemmGrouped::value>::type> -{ - - // If true, we must construct a 'transposed-and-exchanged' Mma operator. - static bool const kInternalTranspose = platform::is_same::value; - - using MapArguments = kernel::detail::MapArguments; - - // Define the default GEMM kernel - using DefaultGemmKernel = typename kernel::DefaultGemm::GemmKernel; - - /// Define the kernel in terms of the default kernel - using GemmKernel = kernel::SplitkGemmGrouped; + ComplexTransform::kNone, // transform A + kAlignmentA, ElementB, LayoutB, + ComplexTransform::kNone, // transform B + kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, + InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, GroupScheduleMode_, Operator, SharedMemoryClear, + PermuteDLayout, typename platform::enable_if::value>::type> { + // If true, we must construct a 'transposed-and-exchanged' Mma operator. 
+ static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = kernel::detail::MapArguments; + + // Define the default GEMM kernel + using DefaultGemmKernel = typename kernel::DefaultGemm::GemmKernel; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::SplitkGemmGrouped; }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h index 053f73103658d..8cd48e6b6e1c9 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -46,522 +46,431 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace detail -{ +namespace detail { template inline constexpr bool dependent_false_v = false; } -template -struct GemmFpAIntB -{ - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static bool const kSplitKSerial = SplitKSerial; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Element; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Mma::LayoutC; - using ElementScale = ElementC; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformA; - - // Type definitions about the mainloop. 
- using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - - /// Parameters structure - struct Arguments - { - GemmUniversalMode mode = GemmUniversalMode::kGemm; - - cutlass::gemm::GemmCoord problem_size; - int group_size; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::TensorRef ref_B; - typename Mma::IteratorScale::TensorRef ref_scale; - typename Mma::IteratorScale::TensorRef ref_zero; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::TensorRef ref_D; - - // Control serial split-k - int batch_count; - - typename EpilogueOutputOp::Params output_op; - - // For gather+scatter operations - int const* gather_A_indices; - int const* gather_B_indices; - int const* scatter_D_indices; - - // Included so we can use Gemm Universal - int batch_stride_D = 0; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Arguments() {} - - CUTLASS_HOST_DEVICE - Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, - typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, - typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, - typename Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, - typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), - int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, - int const* scatter_D_indices = nullptr) - : problem_size(problem_size) - , group_size(group_size) - , ref_A(ref_A) - , ref_B(ref_B) - , ref_scale(ref_scale) - , ref_zero(ref_zero) - , ref_C(ref_C) - , ref_D(ref_D) - , batch_count(serial_split_k_factor) - , output_op(output_op) - , gather_A_indices(gather_A_indices) - , gather_B_indices(gather_B_indices) - , scatter_D_indices(scatter_D_indices) - { - } - }; - - /// Parameters structure - struct Params - { - cutlass::gemm::GemmCoord problem_size; - int group_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::Params params_B; - typename Mma::IteratorB::TensorRef ref_B; - typename Mma::IteratorScale::Params params_scale; - typename Mma::IteratorScale::TensorRef ref_scale; - typename Mma::IteratorScale::TensorRef ref_zero; - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::Params params_D; - typename Epilogue::OutputTileIterator::TensorRef ref_D; - typename EpilogueOutputOp::Params output_op; - int* semaphore; - int gemm_k_size; - // For gather+scatter operations - int const* 
gather_A_indices; - int const* gather_B_indices; - int const* scatter_D_indices; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : swizzle_log_tile(0) - , semaphore(0) - , gemm_k_size(0) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, - void* workspace = nullptr) - : problem_size(args.problem_size) - , group_size(args.group_size) - , grid_tiled_shape(grid_tiled_shape) - , swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)) - , params_A(args.ref_A.layout()) - , ref_A(args.ref_A) - , params_B(args.ref_B.layout()) - , ref_B(args.ref_B) - , params_scale(args.ref_scale.layout()) - , ref_scale(args.ref_scale) - , ref_zero(args.ref_zero) - , params_C(args.ref_C.layout()) - , ref_C(args.ref_C) - , params_D(args.ref_D.layout()) - , ref_D(args.ref_D) - , output_op(args.output_op) - , semaphore(static_cast(workspace)) - , gemm_k_size(gemm_k_size) - , gather_A_indices(args.gather_A_indices) - , gather_B_indices(args.gather_B_indices) - , scatter_D_indices(args.scatter_D_indices) - { - } - }; - - /// Shared memory storage structure - union SharedStorage - { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; +template +struct GemmFpAIntB { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. 
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + int group_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; // // Methods // CUTLASS_HOST_DEVICE - GemmFpAIntB() {} + Arguments() {} - /// Determines whether kernel satisfies alignment CUTLASS_HOST_DEVICE - static Status can_implement(Arguments const& args) - { - static int const kAlignmentA - = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) - ? 64 - : Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB - = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) - ? 64 - : Mma::IteratorB::AccessType::kElements; - - static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; - - static int const kAlignmentC = (platform::is_same>::value) - ? 32 - : (platform::is_same>::value) - ? 
64 - : Epilogue::OutputTileIterator::kElementsPerAccess; - - if (!TensorRef_aligned(args.ref_A, kAlignmentA)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_B, kAlignmentB)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_C, kAlignmentC)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_D, kAlignmentC)) - { - return Status::kErrorMisalignedOperand; - } - - if (!args.ref_scale.good()) - { - return Status::kErrorNotSupported; - } - - if constexpr (hasZero(Mma::QuantOp)) - { - if (!args.ref_zero.good()) - { - return Status::kErrorNotSupported; - } - } - else - { - if (args.ref_zero.good()) - { - return Status::kErrorNotSupported; - } - } - - if constexpr (isFinegrained(Mma::QuantOp)) - { - if (args.group_size != 64 && args.group_size != 128) - { - return Status::kErrorNotSupported; - } - } - - return Status::kSuccess; + Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, + typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : problem_size(problem_size), group_size(group_size), ref_A(ref_A), ref_B(ref_B), ref_scale(ref_scale), ref_zero(ref_zero), ref_C(ref_C), ref_D(ref_D), batch_count(serial_split_k_factor), output_op(output_op), gather_A_indices(gather_A_indices), gather_B_indices(gather_B_indices), scatter_D_indices(scatter_D_indices) { } + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + int group_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { + // + // Methods + // - return 0; + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } - // Initializes the fine grained scale+bias iterator. 
Needed since the fine grained iterator - // has a different constructor signature than a regular cutlass iterator - template = true> - CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, - typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, - typename IteratorScale::TensorCoord extent, int thread_id, - typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) - { - - return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + CUTLASS_HOST_DEVICE + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size), group_size(args.group_size), grid_tiled_shape(grid_tiled_shape), swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), params_A(args.ref_A.layout()), ref_A(args.ref_A), params_B(args.ref_B.layout()), ref_B(args.ref_B), params_scale(args.ref_scale.layout()), ref_scale(args.ref_scale), ref_zero(args.ref_zero), params_C(args.ref_C.layout()), ref_C(args.ref_C), params_D(args.ref_D.layout()), ref_D(args.ref_D), output_op(args.output_op), semaphore(static_cast(workspace)), gemm_k_size(gemm_k_size), gather_A_indices(args.gather_A_indices), gather_B_indices(args.gather_B_indices), scatter_D_indices(args.scatter_D_indices) { } - - template = true> - CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, - typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, - typename IteratorScale::TensorCoord extent, int thread_id, - typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) - { - - return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement(Arguments const& args) { + static int const kAlignmentA = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + + static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; + + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 
64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(args.ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; } - CUTLASS_DEVICE - void run_kernel_(Params const& params, SharedStorage& shared_storage) - { - using LayoutB = typename Mma::IteratorB::Layout; - static_assert(platform::is_same::value && kInterleave == 1 - || platform::is_same::value && kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; + if (!TensorRef_aligned(args.ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; + } - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() - || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) - { + if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; + } - return; - } + if (!TensorRef_aligned(args.ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size, - }; + if (!TensorRef_aligned(args.ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } - cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, - threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + if (!args.ref_scale.good()) { + return Status::kErrorNotSupported; + } - typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; - typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; - cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; + if constexpr (hasZero(Mma::QuantOp)) { + if (!args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } else { + if (args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } - // Problem size is a function of threadblock index in the K dimension - int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + if constexpr (isFinegrained(Mma::QuantOp)) { + if (args.group_size != 64 && args.group_size != 128) { + return Status::kErrorNotSupported; + } + } - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + + // Initializes the fine grained scale+bias iterator. 
Needed since the fine grained iterator + // has a different constructor signature than a regular cutlass iterator + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) { + return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + } + + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) { + return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && kInterleave == 1 || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } - // Compute position within threadblock - int thread_idx = threadIdx.x; + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), - {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; - typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), - {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, - params.gather_B_indices); + typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; + cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; - typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1; - typename Mma::IteratorScale iterator_scale = initialize_scale( - params.params_scale, params.ref_scale.data(), params.ref_zero.data(), - {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. 
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - int lane_idx = threadIdx.x % 32; + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; - // - // Main loop - // - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + // Compute position within threadblock + int thread_idx = threadIdx.x; - typename Mma::FragmentC accumulators; + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); - accumulators.clear(); + typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, + params.gather_B_indices); - if (!kSplitKSerial || gemm_k_iterations > 0) - { - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); - } + typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = initialize_scale( + params.params_scale, params.ref_scale.data(), params.ref_zero.data(), + {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); - // - // Epilogue - // + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; - EpilogueOutputOp output_op(params.output_op); + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); - // - // Masked tile iterators constructed from members - // + typename Mma::FragmentC accumulators; - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + accumulators.clear(); - // assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + // + // Epilogue + // - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); + EpilogueOutputOp output_op(params.output_op); - // If performing a reduction via split-K, fetch the initial synchronization - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { + // + // Masked tile iterators constructed from members + // - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); - // Tile iterator loading from source tensor. 
- typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), - thread_idx, threadblock_offset, params.scatter_D_indices); + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), - thread_idx, threadblock_offset, params.scatter_D_indices); + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); - // Wait on the semaphore - this latency may have been covered by iterator construction - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) - { - iterator_C = iterator_D; - } + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); - semaphore.wait(threadblock_tile_offset.k()); - } + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - // - // Release the semaphore - // + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { + semaphore.wait(threadblock_tile_offset.k()); + } - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) - { + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } - else - { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } + // + // Release the semaphore + // - semaphore.release(lock); - } + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); } - - template - CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) - { - if constexpr (platform::is_same::value) - { - run_kernel_(params, shared_storage); - } - else - { - CUTLASS_NOT_IMPLEMENTED(); - } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); } - - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { #if defined(__CUDA_ARCH__) #if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 900) - CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. + CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. #else - static_assert( - false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); + static_assert( + false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); #endif #else - CUTLASS_NOT_IMPLEMENTED(); + CUTLASS_NOT_IMPLEMENTED(); #endif - } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h index 54b077ff72f9b..4d1cad003d6d7 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -44,405 +44,407 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// // This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms. // It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global. 
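// A minimal, standalone sketch of the void_t detection idiom that the use_dq_gemm trait
// below relies on: the trait evaluates to true only when the inspected Mma type exposes a
// nested IteratorScale typedef, which is what selects the dequantizing GEMM path. The names
// HasScaleIterator, PlainMma and MmaWithScale are illustrative only and are not part of this
// patch or of CUTLASS.
#include <type_traits>

template <typename...>
using void_t_sketch = void;

template <typename T, typename = void>
struct HasScaleIterator : std::false_type {};

template <typename T>
struct HasScaleIterator<T, void_t_sketch<typename T::IteratorScale>> : std::true_type {};

struct PlainMma {};
struct MmaWithScale {
  using IteratorScale = int;  // stand-in for a real scale iterator type
};

static_assert(!HasScaleIterator<PlainMma>::value, "no IteratorScale -> regular GEMM path");
static_assert(HasScaleIterator<MmaWithScale>::value, "IteratorScale present -> dequantizing GEMM path");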
-template using void_t = void; +template +using void_t = void; -template struct use_dq_gemm : platform::false_type {}; +template +struct use_dq_gemm : platform::false_type {}; -template struct use_dq_gemm> : platform::true_type {}; +template +struct use_dq_gemm> : platform::true_type {}; ///////////////////////////////////////////////////////////////////////////////////////////////// -template struct MoeFCGemm { - public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - static bool const kTransposed = false; - - // Optional transpose - using MapArguments = kernel::detail::MapArguments< - typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout, Mma::kTransformA, - Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element, typename Mma::IteratorB::Layout, - Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC, kTransposed>; - - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. - static_assert(!kTransposed, "Transpose problem not supported"); - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename MapArguments::LayoutC; - using ElementScale = ElementC; + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = false; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments< + typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout, Mma::kTransformA, + Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element, typename Mma::IteratorB::Layout, + Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC, kTransposed>; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + static_assert(!kTransposed, "Transpose problem not supported"); + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. 
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = + GemmMoeProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; + int problem_count; + int threadblock_count; + int group_size; - // Type definitions about the mainloop. - using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; + typename EpilogueOutputOp::Params output_op; - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; + int64_t* total_rows_before_expert; + int64_t gemm_n; + int64_t gemm_k; - using ProblemVisitor = - GemmMoeProblemVisitor; + // Only used by device-level operator + GemmCoord* host_problem_sizes; // - // Structures + // Methods // - /// Argument structure - struct Arguments { - // - // Data members - // - - int problem_count; - int threadblock_count; - int group_size; - - typename EpilogueOutputOp::Params output_op; - - ElementA *ptr_A; - ElementB *ptr_B; - ElementScale *weight_scales; - ElementC *ptr_C; - ElementC *ptr_D; - - int64_t *total_rows_before_expert; - int64_t gemm_n; - int64_t gemm_k; - - // Only used by device-level operator - GemmCoord *host_problem_sizes; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() - : problem_count(0), threadblock_count(0), ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), - ptr_C(nullptr), ptr_D(nullptr), total_rows_before_expert(nullptr), gemm_n(0), gemm_k(0), - host_problem_sizes(nullptr) {} - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op, - ElementA const *ptr_A, ElementB const *ptr_B, ElementScale const *weight_scales, - ElementC const *ptr_C, ElementC *ptr_D, int64_t *total_rows_before_expert, int64_t gemm_n, - int64_t gemm_k, GemmCoord *host_problem_sizes = nullptr) - : problem_count(problem_count), threadblock_count(threadblock_count), group_size(group_size), - output_op(output_op), ptr_A(const_cast(ptr_A)), ptr_B(const_cast(ptr_B)), - 
weight_scales(const_cast(weight_scales)), ptr_C(const_cast(ptr_C)), - ptr_D(ptr_D), total_rows_before_expert(total_rows_before_expert), gemm_n(gemm_n), gemm_k(gemm_k), - host_problem_sizes(nullptr) { - if (platform::is_same::value || platform::is_same::value) { - assert(weight_scales); - } - } - }; + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), threadblock_count(0), ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr), total_rows_before_expert(nullptr), gemm_n(0), gemm_k(0), host_problem_sizes(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op, + ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, + ElementC const* ptr_C, ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, + int64_t gemm_k, GemmCoord* host_problem_sizes = nullptr) + : problem_count(problem_count), threadblock_count(threadblock_count), group_size(group_size), output_op(output_op), ptr_A(const_cast(ptr_A)), ptr_B(const_cast(ptr_B)), weight_scales(const_cast(weight_scales)), ptr_C(const_cast(ptr_C)), ptr_D(ptr_D), total_rows_before_expert(total_rows_before_expert), gemm_n(gemm_n), gemm_k(gemm_k), host_problem_sizes(nullptr) { + if (platform::is_same::value || platform::is_same::value) { + assert(weight_scales); + } + } + }; - // - // Structure for precomputing values in host memory and passing to kernels - // + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int group_size; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; - /// Parameters structure - struct Params { - typename ProblemVisitor::Params problem_visitor; - int threadblock_count; - int group_size; - - typename EpilogueOutputOp::Params output_op; - - ElementA *ptr_A; - ElementB *ptr_B; - ElementScale *weight_scales; - ElementC *ptr_C; - ElementC *ptr_D; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() : ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr) {} - - CUTLASS_HOST_DEVICE - Params(Arguments const &args, void *workspace = nullptr, int tile_count = 0) - : problem_visitor(args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, - tile_count), - threadblock_count(args.threadblock_count), group_size(args.group_size), output_op(args.output_op), - ptr_A(args.ptr_A), ptr_B(args.ptr_B), weight_scales(args.weight_scales), ptr_C(args.ptr_C), - ptr_D(args.ptr_D) {} - - CUTLASS_HOST_DEVICE - void update(Arguments const &args, void *workspace = nullptr, int tile_count = 0) { - problem_visitor = typename ProblemVisitor::Params(args.total_rows_before_expert, args.gemm_n, args.gemm_k, - args.problem_count, workspace, tile_count); - threadblock_count = args.threadblock_count; - output_op = args.output_op; - ptr_A = args.ptr_A; - ptr_B = args.ptr_B; - weight_scales = args.weight_scales; - ptr_C = args.ptr_C; - ptr_D = args.ptr_D; - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename ProblemVisitor::SharedStorage problem_visitor; - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - public: // // Methods // - CUTLASS_DEVICE - MoeFCGemm() {} - - /// 
Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const &problem_size) { return Status::kSuccess; } - - static Status can_implement(Arguments const &args) { - if (platform::is_same::value || platform::is_same::value) { - if (args.weight_scales == nullptr) { - CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t"); - return Status::kInvalid; - } - } else if (args.weight_scales != nullptr) { - CUTLASS_TRACE_HOST( - "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); - return Status::kInvalid; - } else if (args.group_size != args.gemm_k) { - CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)"); - return Status::kInvalid; - } - // Handle the case the input is too short - else if (static_cast(args.gemm_n) < Mma::IteratorB::AccessType::kElements) { - CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment"); - return Status::kInvalid; - } - return Status::kSuccess; + CUTLASS_HOST_DEVICE + Params() : ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor(args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, + tile_count), + threadblock_count(args.threadblock_count), + group_size(args.group_size), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + weight_scales(args.weight_scales), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D) {} + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { + problem_visitor = typename ProblemVisitor::Params(args.total_rows_before_expert, args.gemm_n, args.gemm_k, + args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + weight_scales = args.weight_scales; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; } - - static size_t get_extra_workspace_size(Arguments const &args, cutlass::gemm::GemmCoord const &grid_tiled_shape) { - return 0; + }; + + /// Shared memory storage structure + union SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + MoeFCGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { return Status::kSuccess; } + + static Status can_implement(Arguments const& args) { + if (platform::is_same::value || platform::is_same::value) { + if (args.weight_scales == nullptr) { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t"); + return Status::kInvalid; + } + } else if (args.weight_scales != nullptr) { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); + return Status::kInvalid; + } else if (args.group_size != args.gemm_k) { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)"); + return Status::kInvalid; + } + // Handle the case the input is too short + else if (static_cast(args.gemm_n) < Mma::IteratorB::AccessType::kElements) { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n 
is smaller than the input alignment"); + return Status::kInvalid; } + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } - CUTLASS_DEVICE - void run_kernel_(Params const ¶ms, SharedStorage &shared_storage) { - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. - // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - static_assert(platform::is_same::value && kInterleave == 1 || - platform::is_same::value && kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // - // Problem visitor. - // - ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - - const int64_t gemm_k = params.problem_visitor.gemm_k; - const int64_t gemm_n = params.problem_visitor.gemm_n; - int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert(platform::is_same::value && kInterleave == 1 || + platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); - // Outer 'persistent' loop to iterate over tiles - int loop = 0; - while (problem_visitor.next_tile()) { - loop++; - - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); - - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + // + // Problem visitor. + // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - cutlass::gemm::GemmCoord threadblock_offset(int(cta_idx / grid_shape.n()) * Mma::Shape::kM, - int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; - // Load element pointers. Exchange pointers and strides if working on the transpose - const int64_t rows_to_jump = - problem_idx == 0 ? 
0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; - ElementA *ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; - typename LayoutA::LongIndex ldm_A = gemm_k; + // Outer 'persistent' loop to iterate over tiles + int loop = 0; + while (problem_visitor.next_tile()) { + loop++; - char *byte_ptr_B = ((char *)params.ptr_B) + problem_idx * bytes_per_expert_matrix; - ElementB *ptr_B = reinterpret_cast(byte_ptr_B); - typename LayoutB::LongIndex ldm_B = - platform::is_same::value ? gemm_n : gemm_k * kInterleave; + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_offset.m(), - 0, - }; + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + cutlass::gemm::GemmCoord threadblock_offset(int(cta_idx / grid_shape.n()) * Mma::Shape::kM, + int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); - cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + // Load element pointers. Exchange pointers and strides if working on the transpose + const int64_t rows_to_jump = + problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; - // Compute position within threadblock - int thread_idx = threadIdx.x; + char* byte_ptr_B = ((char*)params.ptr_B) + problem_idx * bytes_per_expert_matrix; + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B = + platform::is_same::value ? gemm_n : gemm_k * kInterleave; - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, - tb_offset_A); + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; - typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, - {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, - thread_idx, tb_offset_B); + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; - typename Mma::FragmentC accumulators; + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; - accumulators.clear(); + // Compute position within threadblock + int thread_idx = threadIdx.x; - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. 
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, + tb_offset_A); - int lane_idx = threadIdx.x % 32; + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, + {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, + thread_idx, tb_offset_B); - // - // Matrix multiply phase - // + typename Mma::FragmentC accumulators; - // Construct thread-scoped matrix multiply - auto CreateMMA = [&]() { - if constexpr (use_dq_gemm::value) - return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); - else - return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - }; - Mma mma = CreateMMA(); + accumulators.clear(); - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - // Wait for all threads to finish their epilogue phases from the previous tile. - __syncthreads(); - - // Compute threadblock-scoped matrix multiply-add - ElementScale *weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); + int lane_idx = threadIdx.x % 32; - if constexpr (use_dq_gemm::value) { - const MatrixCoord scale_extent = {1, problem_size.n()}; - typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), - weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale); + // + // Matrix multiply phase + // - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); - } else { - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - } + // Construct thread-scoped matrix multiply + auto CreateMMA = [&]() { + if constexpr (use_dq_gemm::value) + return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + else + return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + }; + Mma mma = CreateMMA(); - // - // Epilogue - // + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - EpilogueOutputOp output_op(params.output_op); + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); - ElementC *ptr_C = reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; - ElementC *ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + // Compute threadblock-scoped matrix multiply-add + ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); - LayoutC layout_C(0); - LayoutC layout_D(gemm_n); + if constexpr (use_dq_gemm::value) { + const MatrixCoord scale_extent = {1, problem_size.n()}; + typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), + weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale); - typename Epilogue::OutputTileIterator::Params params_C(layout_C); - typename Epilogue::OutputTileIterator::Params params_D(layout_D); + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } else { + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + } - // Tile iterator loading from source tensor. 
- typename Epilogue::OutputTileIterator iterator_C(params_C, ptr_C, problem_size.mn(), thread_idx, - threadblock_offset.mn()); + // + // Epilogue + // - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D(params_D, ptr_D, problem_size.mn(), thread_idx, - threadblock_offset.mn()); + EpilogueOutputOp output_op(params.output_op); - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + ElementC* ptr_C = reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; + ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); + LayoutC layout_C(0); + LayoutC layout_D(gemm_n); - // Next tile - problem_visitor.advance(gridDim.x); - } - } + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); - template - CUTLASS_DEVICE void run_kernel(Params const ¶ms, SharedStorage &shared_storage) { - if constexpr (platform::is_same::value) { - run_kernel_(params, shared_storage); - } else { - CUTLASS_NOT_IMPLEMENTED(); - } - } + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params_C, ptr_C, problem_size.mn(), thread_idx, + threadblock_offset.mn()); - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params_D, ptr_D, problem_size.mn(), thread_idx, + threadblock_offset.mn()); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { #if defined(__CUDA_ARCH__) #if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) - run_kernel(params, shared_storage); + run_kernel(params, shared_storage); #elif (__CUDA_ARCH__ >= 900) - run_kernel(params, - shared_storage); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. + run_kernel(params, + shared_storage); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. #else - static_assert( - false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); + static_assert( + false, "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels."); #endif #else - CUTLASS_NOT_IMPLEMENTED(); + CUTLASS_NOT_IMPLEMENTED(); #endif - } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_problem_visitor.h index 796dc2fe78d8e..c0529b2bbbea9 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_problem_visitor.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_problem_visitor.h @@ -25,182 +25,144 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// /// Visitor class to abstract away the algorithm for iterating over tiles template -struct BaseMoeProblemVisitor -{ - using ThreadblockShape = ThreadblockShape_; - - struct ProblemInfo - { - static int32_t const kNoPrefetchEntry = -1; - int32_t problem_idx; - int32_t problem_start; - - CUTLASS_DEVICE - ProblemInfo() - : problem_idx(kNoPrefetchEntry) - , problem_start(kNoPrefetchEntry) - { - } - - CUTLASS_DEVICE - ProblemInfo(int32_t problem_idx_, int32_t problem_start_) - : problem_idx(problem_idx_) - , problem_start(problem_start_) - { - } - }; - - struct Params - { - int64_t const* last_row_for_problem; - int64_t gemm_n; - int64_t gemm_k; - int32_t problem_count; - void const* workspace; - int32_t tile_count; - - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - Params() - : last_row_for_problem(nullptr) - , gemm_n(0) - , gemm_k(0) - , problem_count(0) - , workspace(nullptr) - , tile_count(0) - { - } - - /// Ctor - CUTLASS_HOST_DEVICE - Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count, - void const* workspace = nullptr, int32_t tile_count = 0) - : last_row_for_problem(last_row_for_problem) - , gemm_n(gemm_n) - , gemm_k(gemm_k) - , problem_count(problem_count) - , workspace(workspace) - , tile_count(tile_count) - { - } - }; +struct BaseMoeProblemVisitor { + using ThreadblockShape = ThreadblockShape_; - Params const& params; - int32_t tile_idx; - int32_t problem_tile_start; + struct ProblemInfo { + static int32_t const kNoPrefetchEntry = -1; int32_t problem_idx; + int32_t problem_start; - // - // Methods - // CUTLASS_DEVICE - BaseMoeProblemVisitor(Params const& params_, int32_t block_idx) - : params(params_) - , tile_idx(block_idx) - , problem_tile_start(0) - , problem_idx(0) - { - } - - /// Get the grid shape - CUTLASS_HOST_DEVICE - static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem) - { - - return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), - ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); - } - - /// Gets the global tile index - CUTLASS_HOST_DEVICE - int32_t tile_index() const - { - return tile_idx; - } - - /// Gets the index of the problem - CUTLASS_HOST_DEVICE - int32_t problem_index() const - { - return problem_idx; 
- } - - CUTLASS_HOST_DEVICE - int32_t threadblock_idx() const - { - return tile_idx - problem_tile_start; + ProblemInfo() + : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) { } CUTLASS_DEVICE - void advance(int32_t grid_size) - { - tile_idx += grid_size; + ProblemInfo(int32_t problem_idx_, int32_t problem_start_) + : problem_idx(problem_idx_), problem_start(problem_start_) { } + }; - CUTLASS_HOST_DEVICE - static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) - { - ProblemSizeHelper::possibly_transpose_problem(problem); - } + struct Params { + int64_t const* last_row_for_problem; + int64_t gemm_n; + int64_t gemm_k; + int32_t problem_count; + void const* workspace; + int32_t tile_count; - /// Returns the problem size for the current problem - CUTLASS_HOST_DEVICE - cutlass::gemm::GemmCoord problem_size() const - { - return problem_size(problem_idx); - } + // + // Methods + // + /// Ctor CUTLASS_HOST_DEVICE - cutlass::gemm::GemmCoord problem_size(int idx) const - { - const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; - const int64_t current_problem_row = params.last_row_for_problem[idx]; - const int64_t gemm_m = current_problem_row - prev_problem_row; - GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); - ProblemSizeHelper::possibly_transpose_problem(problem); - return problem; + Params() + : last_row_for_problem(nullptr), gemm_n(0), gemm_k(0), problem_count(0), workspace(nullptr), tile_count(0) { } + /// Ctor CUTLASS_HOST_DEVICE - static int32_t tile_count(cutlass::gemm::GemmCoord const& grid) - { - return ProblemSizeHelper::tile_count(grid); + Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count, + void const* workspace = nullptr, int32_t tile_count = 0) + : last_row_for_problem(last_row_for_problem), gemm_n(gemm_n), gemm_k(gemm_k), problem_count(problem_count), workspace(workspace), tile_count(tile_count) { } - - static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count) - { - int32_t total_tiles = 0; - for (int32_t i = 0; i < problem_count; ++i) - { - auto problem = host_problem_sizes_ptr[i]; - possibly_transpose_problem(problem); - auto grid = grid_shape(problem); - total_tiles += tile_count(grid); - } - - return total_tiles; + }; + + Params const& params; + int32_t tile_idx; + int32_t problem_tile_start; + int32_t problem_idx; + + // + // Methods + // + CUTLASS_DEVICE + BaseMoeProblemVisitor(Params const& params_, int32_t block_idx) + : params(params_), tile_idx(block_idx), problem_tile_start(0), problem_idx(0) { + } + + /// Get the grid shape + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem) { + return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); + } + + /// Gets the global tile index + CUTLASS_HOST_DEVICE + int32_t tile_index() const { + return tile_idx; + } + + /// Gets the index of the problem + CUTLASS_HOST_DEVICE + int32_t problem_index() const { + return problem_idx; + } + + CUTLASS_HOST_DEVICE + int32_t threadblock_idx() const { + return tile_idx - problem_tile_start; + } + + CUTLASS_DEVICE + void advance(int32_t grid_size) { + tile_idx += grid_size; + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { + 
ProblemSizeHelper::possibly_transpose_problem(problem); + } + + /// Returns the problem size for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size() const { + return problem_size(problem_idx); + } + + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size(int idx) const { + const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; + const int64_t current_problem_row = params.last_row_for_problem[idx]; + const int64_t gemm_m = current_problem_row - prev_problem_row; + GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + CUTLASS_HOST_DEVICE + static int32_t tile_count(cutlass::gemm::GemmCoord const& grid) { + return ProblemSizeHelper::tile_count(grid); + } + + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count) { + int32_t total_tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + auto problem = host_problem_sizes_ptr[i]; + possibly_transpose_problem(problem); + auto grid = grid_shape(problem); + total_tiles += tile_count(grid); } + + return total_tiles; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// template + int PrefetchTileCount, int ThreadCount> struct MoeProblemVisitor; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -208,137 +170,122 @@ struct MoeProblemVisitor; // template struct MoeProblemVisitor : public BaseMoeProblemVisitor -{ - using Base = BaseMoeProblemVisitor; - using Params = typename Base::Params; - static int const kThreadCount = ThreadCount; - static bool const kRequiresPrecomputation = false; - static int const kThreadsPerWarp = 32; - - struct SharedStorage - { - }; - - // Final tile of the problem loaded by this thread. Each thread will hold - // a separate value. - int32_t problem_ending_tile; - - SharedStorage& shared_storage; - - // - // Methods - // - CUTLASS_DEVICE - MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) - : Base(params_, block_idx) - , problem_ending_tile(0) - , shared_storage(shared_storage_) - { - this->problem_idx = -1 * kThreadsPerWarp; - this->problem_tile_start = 0; + ThreadCount> : public BaseMoeProblemVisitor { + using Base = BaseMoeProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static bool const kRequiresPrecomputation = false; + static int const kThreadsPerWarp = 32; + + struct SharedStorage { + }; + + // Final tile of the problem loaded by this thread. Each thread will hold + // a separate value. + int32_t problem_ending_tile; + + SharedStorage& shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, block_idx), problem_ending_tile(0), shared_storage(shared_storage_) { + this->problem_idx = -1 * kThreadsPerWarp; + this->problem_tile_start = 0; + } + + CUTLASS_DEVICE + bool next_tile() { + // Check whether the tile to compute is within the range of the current problem. 
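// Host-side sketch of how problem_size(idx) above derives each expert's GEMM from
// last_row_for_problem, an inclusive running row count per expert: expert i owns rows
// [last_row[i-1], last_row[i]) of the stacked activation/output matrices, so gemm_m is the
// difference of consecutive entries. The numbers below are illustrative only.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t last_row_for_problem[] = {3, 3, 8, 12};  // 4 experts; expert 1 received no rows
  const int problem_count = 4;
  for (int idx = 0; idx < problem_count; ++idx) {
    const int64_t prev = (idx == 0) ? 0 : last_row_for_problem[idx - 1];
    const int64_t gemm_m = last_row_for_problem[idx] - prev;
    std::printf("expert %d: rows [%lld, %lld), gemm_m = %lld\n", idx,
                static_cast<long long>(prev),
                static_cast<long long>(last_row_for_problem[idx]),
                static_cast<long long>(gemm_m));
  }
  return 0;
}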
+ int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); + if (this->tile_idx < problem_tile_end) { + return true; } - CUTLASS_DEVICE - bool next_tile() - { - // Check whether the tile to compute is within the range of the current problem. - int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); - if (this->tile_idx < problem_tile_end) - { - return true; + // Check whether the tile to compute is within the current group of problems fetched by the warp. + // The last tile for this group is the final tile of the problem held by the final thread in the warp. + int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + // Keep the starting problem for this group in `problem_idx`. This is done to reduce + // register pressure. The starting problem for this group is simply the first problem + // in the group most recently fetched by the warp. + int32_t& group_problem_start = this->problem_idx; + group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; + + // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce + // register pressure. + int32_t& group_tile_start = this->problem_tile_start; + + // Each thread in the warp processes a separate problem to advance until + // reaching a problem whose starting tile is less less than tile_idx. + while (group_tile_end <= this->tile_idx) { + group_problem_start += kThreadsPerWarp; + if (group_problem_start > this->params.problem_count) { + return false; + } + + // Since `group_tile_start` is a reference to `this->problem_tile_start`, this + // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` + // is also set here is used later in `next_tile`. + group_tile_start = group_tile_end; + + int lane_idx = threadIdx.x % kThreadsPerWarp; + int32_t lane_problem = group_problem_start + lane_idx; + + // Compute the number of tiles in the problem assigned to each thread. + problem_ending_tile = 0; + if (lane_problem < this->params.problem_count) { + cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + problem_ending_tile = this->tile_count(grid); + } + + // Compute a warp-wide inclusive prefix sum to compute the ending tile index of + // each thread's problem. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kThreadsPerWarp; i <<= 1) { + int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); + if (lane_idx >= i) { + problem_ending_tile += val; } + } - // Check whether the tile to compute is within the current group of problems fetched by the warp. - // The last tile for this group is the final tile of the problem held by the final thread in the warp. - int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); - - // Keep the starting problem for this group in `problem_idx`. This is done to reduce - // register pressure. The starting problem for this group is simply the first problem - // in the group most recently fetched by the warp. - int32_t& group_problem_start = this->problem_idx; - group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; - - // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce - // register pressure. 
- int32_t& group_tile_start = this->problem_tile_start; - - // Each thread in the warp processes a separate problem to advance until - // reaching a problem whose starting tile is less less than tile_idx. - while (group_tile_end <= this->tile_idx) - { - group_problem_start += kThreadsPerWarp; - if (group_problem_start > this->params.problem_count) - { - return false; - } - - // Since `group_tile_start` is a reference to `this->problem_tile_start`, this - // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` - // is also set here is used later in `next_tile`. - group_tile_start = group_tile_end; - - int lane_idx = threadIdx.x % kThreadsPerWarp; - int32_t lane_problem = group_problem_start + lane_idx; - - // Compute the number of tiles in the problem assigned to each thread. - problem_ending_tile = 0; - if (lane_problem < this->params.problem_count) - { - cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem); - cutlass::gemm::GemmCoord grid = this->grid_shape(problem); - problem_ending_tile = this->tile_count(grid); - } - - // Compute a warp-wide inclusive prefix sum to compute the ending tile index of - // each thread's problem. - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < kThreadsPerWarp; i <<= 1) - { - int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); - if (lane_idx >= i) - { - problem_ending_tile += val; - } - } - - // The total tile count for this group is now in the final position of the prefix sum - int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); - - problem_ending_tile += group_tile_start; - group_tile_end += tiles_in_group; - } + // The total tile count for this group is now in the final position of the prefix sum + int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); - // The next problem to process is the first one that does not have ending tile position - // that is greater than or equal to tile index. - int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); + problem_ending_tile += group_tile_start; + group_tile_end += tiles_in_group; + } - this->problem_idx = group_problem_start + problem_idx_in_group; + // The next problem to process is the first one that does not have ending tile position + // that is greater than or equal to tile index. + int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); - // The starting tile for this problem is the ending tile of the previous problem. In cases - // where `problem_idx_in_group` is the first problem in the group, we do not need to reset - // `problem_tile_start`, because it is set to the previous group's ending tile in the while - // loop above. - if (problem_idx_in_group > 0) - { - this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); - } + this->problem_idx = group_problem_start + problem_idx_in_group; - return true; + // The starting tile for this problem is the ending tile of the previous problem. In cases + // where `problem_idx_in_group` is the first problem in the group, we do not need to reset + // `problem_tile_start`, because it is set to the previous group's ending tile in the while + // loop above. 
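// A self-contained device sketch of the warp-wide inclusive prefix sum used in next_tile()
// above: each lane starts with its own per-problem tile count and, after log2(32) shuffle
// steps, lane i holds the sum of lanes 0..i, i.e. that problem's ending tile index within
// the group. The function name is illustrative only.
__device__ int warp_inclusive_prefix_sum(int value) {
  const int lane_idx = threadIdx.x % 32;
#pragma unroll
  for (int i = 1; i < 32; i <<= 1) {
    const int neighbor = __shfl_up_sync(0xffffffff, value, i);
    if (lane_idx >= i) {
      value += neighbor;
    }
  }
  return value;  // lane i now holds input[0] + ... + input[i]
}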
+ if (problem_idx_in_group > 0) { + this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); } - static size_t get_workspace_size( - cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count) - { - return 0; - } + return true; + } - static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, - int32_t block_count, void* host_workspace_ptr) - { - } + static size_t get_workspace_size( + cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count) { + return 0; + } + + static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, + int32_t block_count, void* host_workspace_ptr) { + } }; -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h index 5e3531f093811..79b2c1c12c2da 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h @@ -49,446 +49,361 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ +namespace cutlass { +namespace gemm { +namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -template -struct SplitkGemmGrouped -{ -public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - static bool const kTransposed = Transposed; - - // Optional transpose - using MapArguments = kernel::detail::MapArguments; - - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename MapArguments::LayoutC; +template +struct SplitkGemmGrouped { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. 
+ using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + + using ElementFinalOutput = typename MapArguments::ElementA; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = GemmGroupedProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // - using ElementFinalOutput = typename MapArguments::ElementA; + GemmCoord* problem_sizes; + int problem_count; + int threadblock_count; - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; + typename EpilogueOutputOp::Params output_op; - // Type definitions about the mainloop. 
- using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; + // Only used by device-level operator + GemmCoord* host_problem_sizes; - using ProblemVisitor - = GemmGroupedProblemVisitor; + // splitK + int split_k_slices; + int64_t* splitk_buffer_offsets; // - // Structures + // Methods // - /// Argument structure - struct Arguments - { - - // - // Data members - // - - GemmCoord* problem_sizes; - int problem_count; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA** ptr_A; - ElementB** ptr_B; - ElementFinalOutput** ptr_C; - ElementFinalOutput** ptr_D; - - typename LayoutA::Stride::LongIndex* lda; - typename LayoutB::Stride::LongIndex* ldb; - typename LayoutC::Stride::LongIndex* ldc; - typename LayoutC::Stride::LongIndex* ldd; - - // Only used by device-level operator - GemmCoord* host_problem_sizes; - - // splitK - int split_k_slices; - int64_t* splitk_buffer_offsets; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() - : problem_count(0) - , threadblock_count(0) - , ptr_A(nullptr) - , ptr_B(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , lda(nullptr) - , ldb(nullptr) - , ldc(nullptr) - , ldd(nullptr) - , host_problem_sizes(nullptr) - , split_k_slices(1) - , splitk_buffer_offsets(nullptr) - { - } - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count, - typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C, - ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda, - typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc, - typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices, - int64_t* splitk_buffer_offsets) - : problem_sizes(problem_sizes) - , problem_count(problem_count) - , threadblock_count(threadblock_count) - , output_op(output_op) - , ptr_A(ptr_A) - , ptr_B(ptr_B) - , ptr_C(ptr_C) - , ptr_D(ptr_D) - , lda(lda) - , ldb(ldb) - , ldc(ldc) - , ldd(ldd) - , host_problem_sizes(host_problem_sizes) - , split_k_slices(split_k_slices) - , splitk_buffer_offsets(splitk_buffer_offsets) - { - } - }; + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), threadblock_count(0), ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), lda(nullptr), ldb(nullptr), ldc(nullptr), ldd(nullptr), host_problem_sizes(nullptr), split_k_slices(1), splitk_buffer_offsets(nullptr) { + } - // - // Structure for precomputing values in host memory and passing to kernels - // + /// Ctor + CUTLASS_HOST_DEVICE + 
Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count, + typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C, + ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda, + typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc, + typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices, + int64_t* splitk_buffer_offsets) + : problem_sizes(problem_sizes), problem_count(problem_count), threadblock_count(threadblock_count), output_op(output_op), ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), host_problem_sizes(host_problem_sizes), split_k_slices(split_k_slices), splitk_buffer_offsets(splitk_buffer_offsets) { + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + ElementC* ptr_C_split; + ElementC* ptr_D_split; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; - /// Parameters structure - struct Params - { - - typename ProblemVisitor::Params problem_visitor; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA** ptr_A; - ElementB** ptr_B; - ElementFinalOutput** ptr_C; - ElementFinalOutput** ptr_D; - ElementC* ptr_C_split; - ElementC* ptr_D_split; - - typename LayoutA::Stride::LongIndex* lda; - typename LayoutB::Stride::LongIndex* ldb; - typename LayoutC::Stride::LongIndex* ldc; - typename LayoutC::Stride::LongIndex* ldd; - - // - // Methods - // - - // splitk - GemmCoord grid_tiled_shape; - int swizzle_log_tile; - int gemm_k_size; - GemmCoord* host_problem_sizes; - int split_k_slices; - int64_t* splitk_buffer_offsets; - - CUTLASS_HOST_DEVICE - Params() - : ptr_A(nullptr) - , ptr_B(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , ptr_C_split(nullptr) - , ptr_D_split(nullptr) - , lda(nullptr) - , ldb(nullptr) - , ldc(nullptr) - , ldd(nullptr) - , swizzle_log_tile(0) - , gemm_k_size(0) - , host_problem_sizes(nullptr) - , split_k_slices(1) - , splitk_buffer_offsets(nullptr) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count) - , host_problem_sizes(args.host_problem_sizes) - , threadblock_count(args.threadblock_count) - , output_op(args.output_op) - , ptr_A(args.ptr_A) - , ptr_B(args.ptr_B) - , ptr_C(args.ptr_C) - , ptr_D(args.ptr_D) - , ptr_C_split((ElementC*) workspace) - , ptr_D_split((ElementC*) workspace) - , lda(args.lda) - , ldb(args.ldb) - , ldc(args.ldc) - , ldd(args.ldd) - , split_k_slices(args.split_k_slices) - , splitk_buffer_offsets(args.splitk_buffer_offsets) - { - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0], - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices); - swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); - - // only support same k - int full_gemm_k_iterations = args.host_problem_sizes[0].k() / 
Mma::Shape::kK; - int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); - - gemm_k_size = gemm_k_iterations * Mma::Shape::kK; - } - - CUTLASS_HOST_DEVICE - void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - { - - problem_visitor = - typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); - threadblock_count = args.threadblock_count; - output_op = args.output_op; - ptr_A = args.ptr_A; - ptr_B = args.ptr_B; - ptr_C = args.ptr_C; - ptr_D = args.ptr_D; - ptr_C_split = workspace; - ptr_D_split = workspace; - - lda = args.lda; - ldb = args.ldb; - ldc = args.ldc; - ldd = args.ldd; - } - }; - - /// Shared memory storage structure - struct SharedStorage - { - union - { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - } kernel; - - // ProblemVisitor shared storage can't be overlapped with others - typename ProblemVisitor::SharedStorage problem_visitor; - }; - -public: // // Methods // - CUTLASS_DEVICE - SplitkGemmGrouped() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - return Status::kSuccess; + // splitk + GemmCoord grid_tiled_shape; + int swizzle_log_tile; + int gemm_k_size; + GemmCoord* host_problem_sizes; + int split_k_slices; + int64_t* splitk_buffer_offsets; + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), ptr_C_split(nullptr), ptr_D_split(nullptr), lda(nullptr), ldb(nullptr), ldc(nullptr), ldd(nullptr), swizzle_log_tile(0), gemm_k_size(0), host_problem_sizes(nullptr), split_k_slices(1), splitk_buffer_offsets(nullptr) { } - static Status can_implement(Arguments const& args) - { - return Status::kSuccess; - } + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), host_problem_sizes(args.host_problem_sizes), threadblock_count(args.threadblock_count), output_op(args.output_op), ptr_A(args.ptr_A), ptr_B(args.ptr_B), ptr_C(args.ptr_C), ptr_D(args.ptr_D), ptr_C_split((ElementC*)workspace), ptr_D_split((ElementC*)workspace), lda(args.lda), ldb(args.ldb), ldc(args.ldc), ldd(args.ldd), split_k_slices(args.split_k_slices), splitk_buffer_offsets(args.splitk_buffer_offsets) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0], + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices); + swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { + // only support same k + int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK; + int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. 
- // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + } - // - // Problem visitor. - // - ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { + problem_visitor = + typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + ptr_C_split = workspace; + ptr_D_split = workspace; + + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldd = args.ldd; + } + }; + + /// Shared memory storage structure + struct SharedStorage { + union { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + } kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + SplitkGemmGrouped() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; - // Outer 'persistent' loop to iterate over tiles - while (problem_visitor.next_tile()) - { + // + // Problem visitor. + // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - // Load element pointers. Exchange pointers and strides if working on the transpose - ElementA* ptr_A - = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); - typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + // Load element pointers. 
Exchange pointers and strides if working on the transpose + ElementA* ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); - ElementB* ptr_B - = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); - typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); + ElementB* ptr_B = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, - int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0); + cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0); - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_offset.m(), - threadblock_tile_offset.k() * params.gemm_k_size, - }; + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + threadblock_tile_offset.k() * params.gemm_k_size, + }; - cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()}; + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()}; - // Problem size is a function of threadblock index in the K dimension - int problem_size_k; - if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) - { - problem_size_k = problem_size.k(); - } - else - { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } + // Problem size is a function of threadblock index in the K dimension + int problem_size_k; + if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) { + problem_size_k = problem_size.k(); + } else { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; - // Compute position within threadblock - int thread_idx = threadIdx.x; + // Compute position within threadblock + int thread_idx = threadIdx.x; - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); - typename Mma::IteratorB iterator_B( - LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B); + typename Mma::IteratorB iterator_B( + LayoutB(ldm_B), ptr_B, {problem_size_k, 
problem_size.n()}, thread_idx, tb_offset_B); - typename Mma::FragmentC accumulators; + typename Mma::FragmentC accumulators; - accumulators.clear(); + accumulators.clear(); - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); - int lane_idx = threadIdx.x % 32; + int lane_idx = threadIdx.x % 32; - // - // Matrix multiply phase - // + // + // Matrix multiply phase + // - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); - // Wait for all threads to finish their epilogue phases from the previous tile. - __syncthreads(); + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - // - // Epilogue - // + // + // Epilogue + // - EpilogueOutputOp output_op(params.output_op); + EpilogueOutputOp output_op(params.output_op); - ElementC* ptr_C = params.ptr_C_split; - ElementC* ptr_D = params.ptr_D_split; + ElementC* ptr_C = params.ptr_C_split; + ElementC* ptr_D = params.ptr_D_split; - LayoutC layout_C(params.ldc[problem_idx]); - LayoutC layout_D(params.ldd[problem_idx]); + LayoutC layout_C(params.ldc[problem_idx]); + LayoutC layout_D(params.ldd[problem_idx]); - typename Epilogue::OutputTileIterator::Params params_C(layout_C); - typename Epilogue::OutputTileIterator::Params params_D(layout_D); + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); - // assume identity swizzle - MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n()); + // assume identity swizzle + MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n()); - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C); + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C); - iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() - + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + gridDim.z * params.splitk_buffer_offsets[problem_idx]); - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C); - iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() - + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + // Tile iterator writing to destination tensor. 
+ typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C); + iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + gridDim.z * params.splitk_buffer_offsets[problem_idx]); - Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); + Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); - // Next tile - problem_visitor.advance(gridDim.x); - } + // Next tile + problem_visitor.advance(gridDim.x); } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h index c7f51d6fe9fc2..79093fba11674 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -57,12 +57,9 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ +namespace cutlass { +namespace gemm { +namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// /// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
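
As a minimal host-side sketch of the split-K slicing that SplitkGemmGrouped's Params constructor and operator() implement above (this snippet is illustrative only; K, kThreadblockK and split_k_slices are hypothetical stand-ins for the shared problem K, Mma::Shape::kK and args.split_k_slices):

#include <cstdio>

int main() {
  const int K = 4096;            // hypothetical shared K of all grouped problems
  const int kThreadblockK = 64;  // stands in for Mma::Shape::kK
  const int split_k_slices = 4;  // stands in for args.split_k_slices

  // Mirrors the Params constructor: each k-slice covers a whole number of
  // mainloop iterations of kThreadblockK elements.
  int full_gemm_k_iterations = K / kThreadblockK;
  int gemm_k_size = (full_gemm_k_iterations / split_k_slices) * kThreadblockK;

  for (int k_slice = 0; k_slice < split_k_slices; ++k_slice) {
    // Mirrors operator(): slice k starts at k_slice * gemm_k_size; the last
    // slice absorbs the remainder so the slices together cover [0, K).
    int k_begin = k_slice * gemm_k_size;
    int k_end = (k_slice + 1 == split_k_slices) ? K : (k_slice + 1) * gemm_k_size;
    int gemm_k_iterations = (k_end - k_begin + kThreadblockK - 1) / kThreadblockK;
    std::printf("slice %d: K range [%d, %d), %d mainloop iterations\n",
                k_slice, k_begin, k_end, gemm_k_iterations);
  }
  return 0;
}

Each slice writes its partial result to a distinct region of the split-K workspace (the pointer offsets applied to iterator_C and iterator_D above), so the partial sums can later be reduced into the final output; that reduction happens outside this kernel.
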
@@ -92,210 +89,187 @@ template < bool AccumulatorsInRowMajor = false, /// Used for partial specialization typename Enable = bool> -class MmaTensorOpComputeBWithF16 -{ -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; - /// Data type of multiplicand A - using ElementA = ElementA_; + /// Data type of multiplicand A + using ElementA = ElementA_; - /// Layout of multiplicand A - using LayoutA = LayoutA_; + /// Layout of multiplicand A + using LayoutA = LayoutA_; - /// Data type of multiplicand B - using ElementB = ElementB_; + /// Data type of multiplicand B + using ElementB = ElementB_; - /// Layout of multiplicand B - using LayoutB = LayoutB_; + /// Layout of multiplicand B + using LayoutB = LayoutB_; - /// Data type of accumulator matrix C - using ElementC = ElementC_; + /// Data type of accumulator matrix C + using ElementC = ElementC_; - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - static_assert((platform::is_same::value - && platform::is_same::value) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 80), - "MmaTensorOpCvtBToA only supports underlying HMMA"); + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert((platform::is_same::value && platform::is_same::value) || (platform::is_same::value && platform::is_same::value && ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports underlying HMMA"); - static_assert(platform::is_same::value - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80), - "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + static_assert(platform::is_same::value || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; - static_assert( - SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must 
match load"); - static_assert( - SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); + static_assert( + SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); + static_assert( + SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; -public: - /// Iterates over the A operand in memory - using IteratorA - = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, - MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - /// Storage for transformed A tile - using TransformedFragmentA = Array; + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, - LayoutB, MatrixShape, Policy::OpDelta::kRow, - kThreadCount, kPartitionsK>; + /// Storage for transformed A tile + using TransformedFragmentA = Array; - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, + LayoutB, MatrixShape, Policy::OpDelta::kRow, + kThreadCount, kPartitionsK>; - /// Storage for transformed B tile - using TransformedFragmentB = Array; + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, - typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + /// Storage for transformed B tile + using TransformedFragmentB = Array; - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + 
ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; + /// Number of mma operations performed + using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; -public: - // - // Methods - // + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} + public: + // + // Methods + // - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, - int const warp_tileB_k_offset) const - { + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, + int const warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; - static_assert( - TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, - "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " - "B"); + static_assert( + TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " + "B"); - D = C; + D = C; - MmaOperandA const* ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) - { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) - { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; - if (AccumulatorsInRowMajor) - { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else - { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? 
(MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); } + } + } #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) - { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) - { - - int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); - - int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; - if (AccumulatorsInRowMajor) - { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else - { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); } + } + } #else - assert(0); + assert(0); #endif - } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace warp -} // namespace gemm -} // namespace cutlass +} // namespace warp +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h index 3e6fa20d1754c..a744441c95bdf 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h @@ -20,102 +20,104 @@ namespace ort_fastertransformer { // Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape // in the kernel layout details when doing weight only quantization. 
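
The serpentine visitation order in MmaTensorOpComputeBWithF16::operator() above can be visualized with a small stand-alone host sketch. It reproduces the SM80+ branch: m is the outer loop, so ptr_A[m] is reused within a row, and n snakes back and forth so the B fragment used at the end of one row is the first one needed in the next. kRow and kColumn below are hypothetical stand-ins for MmaIterations::kRow and MmaIterations::kColumn.

#include <cstdio>

int main() {
  const int kRow = 2;     // stands in for MmaIterations::kRow
  const int kColumn = 4;  // stands in for MmaIterations::kColumn

  for (int m = 0; m < kRow; ++m) {
    for (int n = 0; n < kColumn; ++n) {
      int n_serpentine = (m % 2) ? (kColumn - 1 - n) : n;  // reverse every other row
      std::printf("mma issue: m=%d n=%d\n", m, n_serpentine);
    }
  }
  return 0;
}

The pre-SM80 branch is the same idea with the loops transposed (the serpentine runs over m) to maximize reuse of the B operand instead.
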
enum class CutlassTileConfig { - // Signals that we should run heuristics do choose a config - Undefined, - - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, - - // SiMT config - CtaShape128x128x8_WarpShape64x64x8, - - // TensorCore configs CTA_N = 128, CTA_K = 64 - // Warp configs for M=16 - CtaShape16x128x64_WarpShape16x32x64, - // Warp configs for M=32 - CtaShape32x128x64_WarpShape32x32x64, - - // Warp configs for M=64 - CtaShape64x128x64_WarpShape32x64x64, - CtaShape64x64x128_WarpShape32x64x64, - CtaShape64x128x64_WarpShape64x32x64, - - // Warp configs for M=128 - CtaShape128x64x64_WarpShape64x32x64, - CtaShape128x128x64_WarpShape64x32x64, - CtaShape128x128x64_WarpShape64x64x64, - CtaShape128x128x64_WarpShape128x32x64, - CtaShape128x256x64_WarpShape64x64x64, - - // Warp configs for M=256 - CtaShape256x128x64_WarpShape64x64x64, - - // TensorCore config CTA_N = 256, CTA_K = 64 - CtaShape16x256x64_WarpShape16x64x64 + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=16 + CtaShape16x128x64_WarpShape16x32x64, + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x64x128_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x64x64, + CtaShape128x128x64_WarpShape128x32x64, + CtaShape128x256x64_WarpShape64x64x64, + + // Warp configs for M=256 + CtaShape256x128x64_WarpShape64x64x64, + + // TensorCore config CTA_N = 256, CTA_K = 64 + CtaShape16x256x64_WarpShape16x64x64 }; enum class SplitKStyle { - NO_SPLIT_K, - SPLIT_K_SERIAL, - // SPLIT_K_PARALLEL // Not supported yet + NO_SPLIT_K, + SPLIT_K_SERIAL, + // SPLIT_K_PARALLEL // Not supported yet }; enum class CutlassTileConfigSM90 { - // Signals that we should run heuristics do choose a config - Undefined, - - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, - - // CTA configs for M=64 - CtaShape64x16x128B, - CtaShape64x32x128B, - CtaShape64x64x128B, - CtaShape64x128x128B, - CtaShape64x256x128B, - - // CTA configs for M=128 - CtaShape128x16x128B, - CtaShape128x32x128B, - CtaShape128x64x128B, - CtaShape128x128x128B, - CtaShape128x256x128B, + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // CTA configs for M=64 + CtaShape64x16x128B, + CtaShape64x32x128B, + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + + // CTA configs for M=128 + CtaShape128x16x128B, + CtaShape128x32x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, }; enum class MainloopScheduleType { - AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this - // defaults to the "legacy" main loop schedule. + AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this + // defaults to the "legacy" main loop schedule. }; enum class EpilogueScheduleType { - AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. 
For - // architectures older than hopper, the epilogue is always performed by the same thread block as the main loop. + AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For + // architectures older than hopper, the epilogue is always performed by the same thread block as the main loop. }; -enum class ClusterShape { ClusterShape_1x1x1, ClusterShape_2x1x1, ClusterShape_1x2x1, ClusterShape_2x2x1 }; +enum class ClusterShape { ClusterShape_1x1x1, + ClusterShape_2x1x1, + ClusterShape_1x2x1, + ClusterShape_2x2x1 }; struct CutlassGemmConfig { - CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; - SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; - int split_k_factor = -1; - int stages = -1; - - // config options for sm90 - CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; - MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; - EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; - ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; - - CutlassGemmConfig() {} - - CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) - : tile_config(tile_config), split_k_style(split_k_style), split_k_factor(split_k_factor), stages(stages) {} - - CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule, - EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) - : tile_config_sm90(tile_config_sm90), mainloop_schedule(mainloop_schedule), - epilogue_schedule(epilogue_schedule), cluster_shape(cluster_shape) {} + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; + + // config options for sm90 + CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; + MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; + EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; + ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + + CutlassGemmConfig() {} + + CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) + : tile_config(tile_config), split_k_style(split_k_style), split_k_factor(split_k_factor), stages(stages) {} + + CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule, + EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) + : tile_config_sm90(tile_config_sm90), mainloop_schedule(mainloop_schedule), epilogue_schedule(epilogue_schedule), cluster_shape(cluster_shape) {} }; -} // namespace ort_fastertransformer +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h index 44ba79680e699..5833020612c8e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h @@ -40,408 +40,373 @@ #include "cutlass/half.h" #include "cutlass/numeric_types.h" -namespace cutlass -{ +namespace cutlass { // This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low // bits and the odd elemeents are in the high bits of the register. 
In addition, it assumes elements were originally // signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. // This converter will uninterleave the data and subtract the bias while converting to the result type. template -struct FastInterleavedAndBiasedNumericArrayConverter -{ +struct FastInterleavedAndBiasedNumericArrayConverter { }; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); - - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); - - // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. 
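+    // Why 1152: prmt packs each biased byte b into the low 8 bits of an fp16 with bit
+    // pattern 0x64xx, which encodes 1024 + b. Since b = signed_value + 128, subtracting
+    // 1152 (0x6480 viewed as fp16) recovers the original signed value.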
+ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } }; template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 4; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } }; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - uint32_t* bf16_result_ptr = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); - - static constexpr uint32_t fp32_base = 0x4B000000; - float fp32_intermediates[4]; - - // Construct FP32s, bfloat does not have enough mantissa for IADD trick - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); - - // Subtract out fp32_base + 128 to make the unsigned integer signed. 
- CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < 4; ++ii) - { - fp32_intermediates[ii] -= 8388736.f; - } - - // Truncate the fp32 representation and pack up as bfloat16s. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < 2; ++ii) - { - bf16_result_ptr[ii] - = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); - } -#else - // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use - // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. - result.clear(); // Suppress compiler warning - arch::device_breakpoint(); -#endif - return result; + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); + // Truncate the fp32 representation and pack up as bfloat16s. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. 
+ result.clear(); // Suppress compiler warning + arch::device_breakpoint(); +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } }; template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 4; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } }; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. - - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. 
- const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. - - // This is the half2 {1032, 1032} represented as an integer. - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. - static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - static constexpr uint32_t NEG_72 = 0xd480d480; - - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. 
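+    // Why the constants below work: a low nibble n OR'd into 0x6400 reads as 1024 + n in
+    // fp16, and n carries a +8 bias, so subtracting 1032 (0x6408) yields the signed value.
+    // A high nibble contributes 16 * n, so an fma by 1/16 (0x2c00) with addend -72
+    // (0xd480), i.e. -(1024 / 16 + 8), yields the signed value without an extra shift.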
+ const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } }; template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 8; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + 
for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } }; template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - uint32_t* h = reinterpret_cast(&result); - uint32_t const source_i4s = reinterpret_cast(source); - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; - - // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. - // No shift needed for first item. - uint32_t i4s = source_i4s; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - CUTLASS_PRAGMA_UNROLL - for (int ii = 1; ii < result_type::kElements / 2; ++ii) - { - i4s >>= sizeof_bits::value; - // (i4s & 0x000f000f) | 0x43004300 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[ii]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - } - - // This is the BF16 {-136, -136} represented as an integer. - static constexpr uint32_t BF16_BIAS = 0xC308C308; - static constexpr uint32_t BF16_ONE = 0x3F803F80; - - // Finally, we construct the output numbers. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < result_type::kElements / 2; ++ii) - { - // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction - asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); - } -#else - // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use - // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. - arch::device_breakpoint(); - result.clear(); // Suppress compiler warning. -#endif - return result; + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. 
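+    // Decode arithmetic behind the magic constants (assuming the usual +8 zero-point for int4
+    // weights): OR-ing a nibble q into 0x4300 produces the bf16 value 128 + q, and the
+    // fma.rn.bf16x2 below with BF16_ONE and BF16_BIAS (-136) recovers q - 8. The fp16
+    // specialization above plays the same trick with 0x6400 (1024.0): it subtracts 1032 for the
+    // bottom nibble, and multiplies by 1/16 then adds -72 for the top nibble, since
+    // (1024 + 16 * q) / 16 - 72 == q - 8.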
+ uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } }; template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 8; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); } - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h index 5a0cd2957082a..e5abefa35bc84 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h @@ -38,29 +38,24 @@ #include "cutlass/matrix_coord.h" #include "cutlass/pitch_linear_coord.h" -namespace cutlass -{ -namespace layout -{ +namespace cutlass { +namespace layout { template -struct ColumnMajorTileInterleave -{ - static constexpr int kRowsPerTile = RowsPerTile; - static constexpr int kColumnsInterleaved = ColumnsInterleaved; +struct ColumnMajorTileInterleave { + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; }; template -struct IsColumnMajorTileInterleave -{ - static constexpr bool value = false; +struct IsColumnMajorTileInterleave { + static constexpr bool value = false; }; template -struct IsColumnMajorTileInterleave> -{ - static constexpr bool value = true; +struct IsColumnMajorTileInterleave> { + static constexpr bool value = true; }; -} // namespace layout -} // namespace cutlass +} // namespace layout +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h index 0d04310f6b85d..40ee5a46e7850 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -50,12 +50,9 @@ //////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace transform -{ -namespace threadblock -{ +namespace cutlass { +namespace transform { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// @@ -63,186 +60,170 @@ template -class FineGrainedScaleZeroIterator -{ -public: - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = 0; - static int const kAlignment = Alignment_; - - static int const kAccessesPerVector = 1; - - /// Row index of scales corresponding to the groupsize of 64 - int row_groupsize64_; - int group_size_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - using Pointer = Element*; - using NonConstPointer = typename platform::remove_const::type*; - - using AccessType = AlignedArray; - - // For compatibility with existing iterator interface - struct Params - { - LongIndex stride_ = 0; - - /// amount (in byte) to increment pointer from first access of current tile - /// to first access of next tile - LongIndex inc_advance_ = 0; - - // Default ctor - CUTLASS_HOST_DEVICE - Params() {} - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const& layout) - : stride_(layout.stride(0)) - { - inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; - } - }; - -private: - /// Internal pointer type permits fast address arithmetic - using BytePointer = char*; - -private: - // - // Data members - // - - /// 
Parameters object with precomputed internal state - Params const params_; - - /// Internal pointer to first access of tile - BytePointer pointer_scale_; - BytePointer pointer_zero_; - - bool is_valid_ = false; - -public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_DEVICE - FineGrainedScaleZeroIterator( - ///< Precomputed parameters object - Params const& params, - ///< Pointer to start of scale tensor - Pointer pointer_scale, - ///< Pointer to start of zero tensor - Pointer pointer_zero, - ///< Extent of the scale and bias - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const& threadblock_offset, - ///< Group size - int group_size) - : params_(params) - , pointer_scale_(reinterpret_cast(const_cast(pointer_scale))) - , pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) - { - row_groupsize64_ = threadblock_offset.row(); - group_size_ = group_size; - - const LongIndex tb_row_byte_offset - = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; - const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; - pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); - - if (pointer_zero_ != nullptr) - { - pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); - } - - static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; - - int const thread_row = thread_id / THREADS_PER_ROW; - int const thread_col = thread_id % THREADS_PER_ROW; - - const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; - const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; - pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); - if (pointer_zero_ != nullptr) - { - pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); - } - - // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on - // a given iteration. The same threads will be responsible for issues reads since the number of scales - // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ - // outside of the constructor. 
- int const global_row = threadblock_offset.row() + thread_row; - int const global_col = threadblock_offset.column() + thread_col * kAlignment; - - bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; - bool const col_in_bounds = global_col < extent.column(); - - is_valid_ = row_in_bounds && col_in_bounds; - } +class FineGrainedScaleZeroIterator { + public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object - Pointer pointer_scale, ///< Pointer to start of scale tensor - Pointer pointer_zero, ///< Pointer to start of zero tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - int group_size) - : FineGrainedScaleZeroIterator( - params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) - { - } + static int const kAccessesPerVector = 1; - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const& tile_offset) - { - const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; - const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; - pointer_scale_ += row_byte_offset + col_byte_offset; - if (pointer_zero_ != nullptr) - { - pointer_zero_ += row_byte_offset + col_byte_offset; - } - } + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) - { - is_valid_ &= (!enable); - } + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using AccessType = AlignedArray; + + // For compatibility with existing iterator interface + struct Params { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; - /// Returns whether access is valid or not + // Default ctor CUTLASS_HOST_DEVICE - bool valid() const - { - return is_valid_; - } + Params() {} - /// Returns a scale pointer + /// Construct the Params object given a pitch-linear tensor's layout CUTLASS_HOST_DEVICE - AccessType* get_scale() const - { - return reinterpret_cast(pointer_scale_); + Params(Layout const& layout) + : stride_(layout.stride(0)) { + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + BytePointer pointer_zero_; + + bool is_valid_ = false; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor 
+ Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params), pointer_scale_(reinterpret_cast(const_cast(pointer_scale))), pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + if (pointer_zero_ != nullptr) { + pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); } - /// Returns a zero pointer - CUTLASS_HOST_DEVICE - AccessType* get_zero() const - { - return reinterpret_cast(pointer_zero_); + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; + + int const thread_row = thread_id / THREADS_PER_ROW; + int const thread_col = thread_id % THREADS_PER_ROW; + + const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + if (pointer_zero_ != nullptr) { + pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); + } + + // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on + // a given iteration. The same threads will be responsible for issues reads since the number of scales + // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ + // outside of the constructor. 
+ int const global_row = threadblock_offset.row() + thread_row; + int const global_col = threadblock_offset.column() + thread_col * kAlignment; + + bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; + bool const col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator( + params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) { + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + if (pointer_zero_ != nullptr) { + pointer_zero_ += row_byte_offset + col_byte_offset; } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) { + is_valid_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return is_valid_; + } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const { + return reinterpret_cast(pointer_scale_); + } + + /// Returns a zero pointer + CUTLASS_HOST_DEVICE + AccessType* get_zero() const { + return reinterpret_cast(pointer_zero_); + } }; -} // namespace threadblock -} // namespace transform -} // namespace cutlass +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h index 64774428e9f90..f3f69d9a0097b 100644 --- a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h @@ -34,25 +34,21 @@ #pragma once -namespace cutlass -{ +namespace cutlass { -enum class WeightOnlyQuantOp -{ - UNDEFINED, - PER_COLUMN_SCALE_ONLY, - FINEGRAINED_SCALE_ONLY, - FINEGRAINED_SCALE_AND_ZEROS +enum class WeightOnlyQuantOp { + UNDEFINED, + PER_COLUMN_SCALE_ONLY, + FINEGRAINED_SCALE_ONLY, + FINEGRAINED_SCALE_AND_ZEROS }; -constexpr bool isFinegrained(WeightOnlyQuantOp op) -{ - return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +constexpr bool isFinegrained(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; } -constexpr bool hasZero(WeightOnlyQuantOp op) -{ - return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +constexpr bool hasZero(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; } -} // namespace cutlass +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h index 899c4de9b44d4..0f75a121b3b92 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h +++ 
b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h @@ -30,10 +30,10 @@ namespace ort_fastertransformer { std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only); -CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector &candidate_configs, - const std::vector &occupancies, const int64_t m, +CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, + const std::vector& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts, const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, const int is_weight_only); -} // namespace ort_fastertransformer +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h index d79e3e085f1f9..13ebbe4888911 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -21,44 +21,51 @@ namespace ort_fastertransformer { -enum class ActivationType { Gelu, Relu, Silu, GeGLU, ReGLU, SiGLU, Identity, InvalidType }; +enum class ActivationType { Gelu, + Relu, + Silu, + GeGLU, + ReGLU, + SiGLU, + Identity, + InvalidType }; template class MoeGemmRunner { - public: - MoeGemmRunner(); + public: + MoeGemmRunner(); - void initialize(int sm); + void initialize(int sm); - void moe_gemm_bias_act(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, - int64_t *total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, ActivationType activation_type, cudaStream_t stream); + void moe_gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, ActivationType activation_type, cudaStream_t stream); - // void moe_gemm_act(const T *A, const WeightType *B, const T *weight_scales, T *C, int64_t - // *total_rows_before_expert, - // int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - // ActivationType activation_type, cudaStream_t stream); + // void moe_gemm_act(const T *A, const WeightType *B, const T *weight_scales, T *C, int64_t + // *total_rows_before_expert, + // int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + // ActivationType activation_type, cudaStream_t stream); - void moe_gemm(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, - int64_t *total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, cudaStream_t stream); + void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream); - private: - template - void dispatch_to_arch(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, - int64_t *total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, cudaStream_t stream, - int *occupancy = nullptr); + private: + template + void dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, 
cudaStream_t stream, + int* occupancy = nullptr); - template - void run_gemm(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, - int64_t *total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, cudaStream_t stream); + template + void run_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream); - private: - int sm_; - int multi_processor_count_; + private: + int sm_; + int multi_processor_count_; }; -} // namespace ort_fastertransformer +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index ab5b42124db2a..476a4e82f8596 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -64,312 +64,312 @@ namespace ort_fastertransformer { // ============================= Variable batched Gemm things =========================== template -void generic_moe_gemm_kernelLauncher(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, - int64_t *total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, +void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, const int multi_processor_count, - cudaStream_t stream, int *kernel_occupancy = nullptr) { - // if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) { - // ORT_THROW("[FT Error][MoeGemm] Grouped gemm does not support split-k"); - // } - - // static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, - // "Specialized for half, float"); - - // static_assert(cutlass::platform::is_same::value || - // cutlass::platform::is_same::value || - // cutlass::platform::is_same::value, - // ""); - - // // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. - // using ElementType_ = - // typename cutlass::platform::conditional::value, cutlass::half_t, - // T>::type; - // using ElementType = ElementType_; - - // using CutlassWeightType_ = - // typename cutlass::platform::conditional::value, cutlass::half_t, - // WeightType>::type; - // using CutlassWeightType = CutlassWeightType_; - - // // We need separate config for each architecture since we will target different tensorcore instructions. For - // float, - // // we do not target TCs. - // using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; - // using ElementAccumulator = typename MixedGemmArchTraits::AccType; - - // using EpilogueOp = - // typename Epilogue::Op; - - // // Finally, set up the kernel. 
- // using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped< - // ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, - // MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType, typename MixedGemmArchTraits::LayoutB, - // cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessB, ElementType, - // cutlass::layout::RowMajor, ElementAccumulator, typename MixedGemmArchTraits::OperatorClass, arch, - // ThreadblockShape, WarpShape, typename MixedGemmArchTraits::InstructionShape, EpilogueOp, - // cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages, - // cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel; - - // using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; - - // using GemmGrouped = cutlass::gemm::device::GemmGrouped; - - // if (kernel_occupancy != nullptr) { - // *kernel_occupancy = compute_occupancy_for_kernel(); - // return; - // } - // int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); - // if (occupancy == 0) { - // ORT_THROW("[FT Error][MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM kernel"); - // } - // const int threadblock_count = multi_processor_count * occupancy; - - // typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); - - // typename GemmGrouped::Arguments args( - // num_experts, threadblock_count, epilogue_op, reinterpret_cast(A), - // reinterpret_cast(B), reinterpret_cast(weight_scales), - // reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, - // gemm_n, gemm_k); - - // GemmGrouped gemm; - - // auto can_implement = gemm.can_implement(args); - // if (can_implement != cutlass::Status::kSuccess) { - // std::string err_msg = - // "MoEFC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); - // ORT_THROW("[FT Error][MoE Runner] " + err_msg); - // } - - // auto init_status = gemm.initialize(args); - // if (init_status != cutlass::Status::kSuccess) { - // std::string err_msg = "Failed to initialize cutlass variable batched gemm. Error: " + - // std::string(cutlassGetStatusString(init_status)); - // ORT_THROW("[FT Error][MoE Runner] " + err_msg); - // } - - // auto run_status = gemm.run(stream); - // if (run_status != cutlass::Status::kSuccess) { - // std::string err_msg = - // "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)); - // ORT_THROW("[FT Error][MoE Runner] " + err_msg); - // } - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for half, float"); - - static_assert(cutlass::platform::is_same::value || - cutlass::platform::is_same::value || - cutlass::platform::is_same::value, - ""); - - // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. - using ElementType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; - using ElementType = ElementType_; - - using CutlassWeightType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, - WeightType>::type; - - using CutlassWeightType = CutlassWeightType_; - - // We need separate config for each architecture since we will target different tensorcore instructions. For - // float, we do not target TCs. 
- using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; - using ElementAccumulator = typename MixedGemmArchTraits::AccType; - - using EpilogueOp = - typename Epilogue::Op; - - // Finally, set up the kernel. - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped< - ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, - MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType, typename MixedGemmArchTraits::LayoutB, - cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessB, ElementType, - cutlass::layout::RowMajor, ElementAccumulator, typename MixedGemmArchTraits::OperatorClass, arch, - ThreadblockShape, WarpShape, typename MixedGemmArchTraits::InstructionShape, EpilogueOp, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages, - cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel; - - using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; - - using GemmGrouped = cutlass::gemm::device::GemmGrouped; - - if (kernel_occupancy != nullptr) { - *kernel_occupancy = compute_occupancy_for_kernel(); - return; - } - int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); - ORT_ENFORCE(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); - int const threadblock_count = multi_processor_count * occupancy; - - typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), - biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); - - int const group_size = gemm_k; - typename GemmGrouped::Arguments args( - num_experts, threadblock_count, group_size, epilogue_op, reinterpret_cast(A), - reinterpret_cast(B), reinterpret_cast(weight_scales), - reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, - gemm_n, gemm_k); - - GemmGrouped gemm; - - auto can_implement = gemm.can_implement(args); - if (can_implement != cutlass::Status::kSuccess) { - std::string err_msg = - "MoEFC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); - ORT_THROW("[FT Error][MoE Runner] " + err_msg); - } - - auto init_status = gemm.initialize(args); - if (init_status != cutlass::Status::kSuccess) { - std::string err_msg = "Failed to initialize cutlass variable batched gemm. Error: " + - std::string(cutlassGetStatusString(init_status)); - ORT_THROW("[FT Error][MoE Runner] " + err_msg); - } - - auto run_status = gemm.run(stream); - if (run_status != cutlass::Status::kSuccess) { - std::string err_msg = - "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)); - ORT_THROW("[FT Error][MoE Runner] " + err_msg); - } + cudaStream_t stream, int* kernel_occupancy = nullptr) { + // if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) { + // ORT_THROW("[FT Error][MoeGemm] Grouped gemm does not support split-k"); + // } + + // static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, + // "Specialized for half, float"); + + // static_assert(cutlass::platform::is_same::value || + // cutlass::platform::is_same::value || + // cutlass::platform::is_same::value, + // ""); + + // // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. 
+ // using ElementType_ = + // typename cutlass::platform::conditional::value, cutlass::half_t, + // T>::type; + // using ElementType = ElementType_; + + // using CutlassWeightType_ = + // typename cutlass::platform::conditional::value, cutlass::half_t, + // WeightType>::type; + // using CutlassWeightType = CutlassWeightType_; + + // // We need separate config for each architecture since we will target different tensorcore instructions. For + // float, + // // we do not target TCs. + // using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + // using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + // using EpilogueOp = + // typename Epilogue::Op; + + // // Finally, set up the kernel. + // using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped< + // ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, + // MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType, typename MixedGemmArchTraits::LayoutB, + // cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessB, ElementType, + // cutlass::layout::RowMajor, ElementAccumulator, typename MixedGemmArchTraits::OperatorClass, arch, + // ThreadblockShape, WarpShape, typename MixedGemmArchTraits::InstructionShape, EpilogueOp, + // cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages, + // cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel; + + // using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + + // using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + // if (kernel_occupancy != nullptr) { + // *kernel_occupancy = compute_occupancy_for_kernel(); + // return; + // } + // int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + // if (occupancy == 0) { + // ORT_THROW("[FT Error][MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM kernel"); + // } + // const int threadblock_count = multi_processor_count * occupancy; + + // typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); + + // typename GemmGrouped::Arguments args( + // num_experts, threadblock_count, epilogue_op, reinterpret_cast(A), + // reinterpret_cast(B), reinterpret_cast(weight_scales), + // reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, + // gemm_n, gemm_k); + + // GemmGrouped gemm; + + // auto can_implement = gemm.can_implement(args); + // if (can_implement != cutlass::Status::kSuccess) { + // std::string err_msg = + // "MoEFC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + // ORT_THROW("[FT Error][MoE Runner] " + err_msg); + // } + + // auto init_status = gemm.initialize(args); + // if (init_status != cutlass::Status::kSuccess) { + // std::string err_msg = "Failed to initialize cutlass variable batched gemm. Error: " + + // std::string(cutlassGetStatusString(init_status)); + // ORT_THROW("[FT Error][MoE Runner] " + err_msg); + // } + + // auto run_status = gemm.run(stream); + // if (run_status != cutlass::Status::kSuccess) { + // std::string err_msg = + // "Failed to run cutlass variable batched gemm. 
Error: " + std::string(cutlassGetStatusString(run_status)); + // ORT_THROW("[FT Error][MoE Runner] " + err_msg); + // } + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float"); + + static_assert(cutlass::platform::is_same::value || + cutlass::platform::is_same::value || + cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; + using ElementType = ElementType_; + + using CutlassWeightType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, + WeightType>::type; + + using CutlassWeightType = CutlassWeightType_; + + // We need separate config for each architecture since we will target different tensorcore instructions. For + // float, we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = + typename Epilogue::Op; + + // Finally, set up the kernel. + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped< + ElementType, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, + MixedGemmArchTraits::ElementsPerAccessA, CutlassWeightType, typename MixedGemmArchTraits::LayoutB, + cutlass::ComplexTransform::kNone, MixedGemmArchTraits::ElementsPerAccessB, ElementType, + cutlass::layout::RowMajor, ElementAccumulator, typename MixedGemmArchTraits::OperatorClass, arch, + ThreadblockShape, WarpShape, typename MixedGemmArchTraits::InstructionShape, EpilogueOp, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, Stages, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, typename MixedGemmArchTraits::Operator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + if (kernel_occupancy != nullptr) { + *kernel_occupancy = compute_occupancy_for_kernel(); + return; + } + int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + ORT_ENFORCE(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); + int const threadblock_count = multi_processor_count * occupancy; + + typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), + biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); + + int const group_size = gemm_k; + typename GemmGrouped::Arguments args( + num_experts, threadblock_count, group_size, epilogue_op, reinterpret_cast(A), + reinterpret_cast(B), reinterpret_cast(weight_scales), + reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, + gemm_n, gemm_k); + + GemmGrouped gemm; + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = + "MoEFC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + ORT_THROW("[FT Error][MoE Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to initialize cutlass variable batched gemm. Error: " + + std::string(cutlassGetStatusString(init_status)); + ORT_THROW("[FT Error][MoE Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = + "Failed to run cutlass variable batched gemm. 
Error: " + std::string(cutlassGetStatusString(run_status)); + ORT_THROW("[FT Error][MoE Runner] " + err_msg); + } } template struct dispatch_stages { - static void dispatch(const T * /*A*/, const WeightType * /*B*/, const T * /*weight_scales*/, const T * /*biases*/, - T * /*C*/, int64_t * /*total_rows_before_expert*/, int64_t /*gemm_n*/, int64_t /*gemm_k*/, - int /*num_experts*/, CutlassGemmConfig /*gemm_config*/, int /*multi_processor_count*/, - cudaStream_t /*stream*/, [[maybe_unused]] int *occupancy = nullptr) { - std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + - std::to_string(arch::kMinComputeCapability) + " with stages set to " + - std::to_string(Stages); - ORT_THROW("[FT Error][dispatch_stages::dispatch] " + err_msg); - } + static void dispatch(const T* /*A*/, const WeightType* /*B*/, const T* /*weight_scales*/, const T* /*biases*/, + T* /*C*/, int64_t* /*total_rows_before_expert*/, int64_t /*gemm_n*/, int64_t /*gemm_k*/, + int /*num_experts*/, CutlassGemmConfig /*gemm_config*/, int /*multi_processor_count*/, + cudaStream_t /*stream*/, [[maybe_unused]] int* occupancy = nullptr) { + std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + + std::to_string(arch::kMinComputeCapability) + " with stages set to " + + std::to_string(Stages); + ORT_THROW("[FT Error][dispatch_stages::dispatch] " + err_msg); + } }; template struct dispatch_stages { - static void dispatch(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, - int64_t *total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, - CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int *occupancy = nullptr) { - generic_moe_gemm_kernelLauncher( - A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, stream, occupancy); - } + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + generic_moe_gemm_kernelLauncher( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + } }; template struct dispatch_stages 2)>::type> { - static void dispatch(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, - int64_t *total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, - CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int *occupancy = nullptr) { - generic_moe_gemm_kernelLauncher(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, - gemm_k, num_experts, gemm_config, multi_processor_count, stream, - occupancy); - } + static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, + CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { + generic_moe_gemm_kernelLauncher(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, + gemm_k, num_experts, gemm_config, multi_processor_count, stream, + occupancy); + } }; template -void dispatch_gemm_config(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, - int64_t 
*total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, +void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int *occupancy = nullptr) { - switch (gemm_config.stages) { + int* occupancy = nullptr) { + switch (gemm_config.stages) { case 2: - using DispatcherStages2 = dispatch_stages; - DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; + using DispatcherStages2 = dispatch_stages; + DispatcherStages2::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; case 3: - using DispatcherStages3 = dispatch_stages; - DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; + using DispatcherStages3 = dispatch_stages; + DispatcherStages3::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; case 4: - using DispatcherStages4 = dispatch_stages; - DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; + using DispatcherStages4 = dispatch_stages; + DispatcherStages4::dispatch(A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count, stream, occupancy); + break; default: - std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); - ORT_THROW("[FT Error][MoE][dispatch_gemm_config] " + err_msg); - break; - } + std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); + ORT_THROW("[FT Error][MoE][dispatch_gemm_config] " + err_msg); + break; + } } // This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. // This overload is only enabled when T == WeightType. 
template < typename T, typename WeightType, typename arch, typename EpilogueTag, - typename std::enable_if::value && std::is_same::value>::type * = nullptr> -void dispatch_moe_gemm_to_cutlass(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, - int64_t *total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, + typename std::enable_if::value && std::is_same::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/, - int multi_processor_count, cudaStream_t stream, int *occupancy = nullptr) { - switch (gemm_config.tile_config) { + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: - ORT_ENFORCE(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); - if constexpr (arch::kMinComputeCapability >= 75) { - dispatch_gemm_config, - cutlass::gemm::GemmShape<16, 32, 64>>( - A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, stream, occupancy); - } - break; - case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: - ORT_ENFORCE(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); - if constexpr (arch::kMinComputeCapability >= 75) { - dispatch_gemm_config, - cutlass::gemm::GemmShape<16, 64, 64>>( - A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, stream, occupancy); - } - break; - case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<32, 32, 64>>( + ORT_ENFORCE(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) { + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 32, 64>>( A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; - case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<32, 64, 64>>( + } + break; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + ORT_ENFORCE(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) { + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 64, 64>>( A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, stream, occupancy); - break; + } + break; + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 64, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + break; case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 32, 64>>( - A, B, weight_scales, biases, C, 
total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, stream, occupancy); - break; + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + break; case CutlassTileConfig::Undefined: - ORT_THROW("GEMM config undefined."); - break; + ORT_THROW("GEMM config undefined."); + break; case CutlassTileConfig::ChooseWithHeuristic: - ORT_THROW("GEMM config should have already been set by heuristic."); - break; + ORT_THROW("GEMM config should have already been set by heuristic."); + break; default: - ORT_THROW("Config is invalid for same type tensorop GEMM."); - break; - } + ORT_THROW("Config is invalid for same type tensorop GEMM."); + break; + } } // Tensorop GEMM overload @@ -377,159 +377,162 @@ void dispatch_moe_gemm_to_cutlass(const T *A, const WeightType *B, const T *weig // compile time template < typename T, typename WeightType, typename arch, typename EpilogueTag, - typename std::enable_if::value && !std::is_same::value>::type * = nullptr> -void dispatch_moe_gemm_to_cutlass(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, - int64_t *total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + typename std::enable_if::value && !std::is_same::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, int sm_version, - int multi_processor_count, cudaStream_t stream, int *occupancy = nullptr) { - switch (gemm_config.tile_config) { + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<32, 32, 64>>( - A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, stream, occupancy); - break; + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + break; case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 32, 64>>( - A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, stream, occupancy); - break; + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + break; case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<128, 32, 64>>( - A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, stream, occupancy); - break; + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + break; case CutlassTileConfig::Undefined: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); - 
break; + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + break; case CutlassTileConfig::ChooseWithHeuristic: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); - break; + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + break; default: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for mixed type tensorop GEMM."); - break; - } + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for mixed type tensorop GEMM."); + break; + } } // This overload will handle simt gemms. It is disabled via SFINAE for tensorop. template ::value>::type * = nullptr> -void dispatch_moe_gemm_to_cutlass(const T *A, const WeightType *B, const T *weight_scales, const T *biases, T *C, - int64_t *total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, + typename std::enable_if::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/, - int multi_processor_count, cudaStream_t stream, int *occupancy = nullptr) { - switch (gemm_config.tile_config) { + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: - dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 64, 8>>( - A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, stream, occupancy); - break; + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 64, 8>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + break; case CutlassTileConfig::Undefined: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config undefined."); - break; + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config undefined."); + break; case CutlassTileConfig::ChooseWithHeuristic: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config should have already been set by " - "heuristic."); - break; + ORT_THROW( + "[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config should have already been set by " + "heuristic."); + break; default: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] Unsupported config for float MoE gemm."); - break; - } + ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] Unsupported config for float MoE gemm."); + break; + } } -template MoeGemmRunner::MoeGemmRunner() {} +template +MoeGemmRunner::MoeGemmRunner() {} -template void MoeGemmRunner::initialize(int sm_version) { - int device{-1}; - cudaGetDevice(&device); - sm_ = sm_version; - cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device); +template +void MoeGemmRunner::initialize(int sm_version) { + int device{-1}; + cudaGetDevice(&device); + sm_ = sm_version; + cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device); } template template -void MoeGemmRunner::dispatch_to_arch(const T *A, const WeightType *B, - const T *weight_scales, const T *biases, T *C, - int64_t *total_rows_before_expert, int64_t total_rows, +void MoeGemmRunner::dispatch_to_arch(const T* A, const 
+                                                    const T* weight_scales, const T* biases, T* C,
+                                                    int64_t* total_rows_before_expert, int64_t total_rows,
                                                     int64_t gemm_n, int64_t gemm_k, int num_experts,
                                                     CutlassGemmConfig gemm_config, cudaStream_t stream,
-                                                    int *occupancy) {
-    if (sm_ >= 70 && sm_ < 75) {
-        dispatch_moe_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm70, EpilogueTag>(
-            A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts,
-            gemm_config, sm_, multi_processor_count_, stream, occupancy);
-    } else if (sm_ >= 75 && sm_ < 80) {
-        dispatch_moe_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm75, EpilogueTag>(
-            A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts,
-            gemm_config, sm_, multi_processor_count_, stream, occupancy);
-    } else if (sm_ >= 80 && sm_ < 90) {
-        dispatch_moe_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(
-            A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts,
-            gemm_config, sm_, multi_processor_count_, stream, occupancy);
-    } else {
-        ORT_THROW("[FT Error][MoE][GEMM Dispatch] Arch unsupported for MoE GEMM");
-    }
+                                                    int* occupancy) {
+  if (sm_ >= 70 && sm_ < 75) {
+    dispatch_moe_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm70, EpilogueTag>(
+        A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts,
+        gemm_config, sm_, multi_processor_count_, stream, occupancy);
+  } else if (sm_ >= 75 && sm_ < 80) {
+    dispatch_moe_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm75, EpilogueTag>(
+        A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts,
+        gemm_config, sm_, multi_processor_count_, stream, occupancy);
+  } else if (sm_ >= 80 && sm_ < 90) {
+    dispatch_moe_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(
+        A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts,
+        gemm_config, sm_, multi_processor_count_, stream, occupancy);
+  } else {
+    ORT_THROW("[FT Error][MoE][GEMM Dispatch] Arch unsupported for MoE GEMM");
+  }
 }
 
 template <typename T, typename WeightType>
 template <typename EpilogueTag>
-void MoeGemmRunner<T, WeightType>::run_gemm(const T *A, const WeightType *B, const T *weight_scales,
-                                            const T *biases, T *C, int64_t *total_rows_before_expert,
+void MoeGemmRunner<T, WeightType>::run_gemm(const T* A, const WeightType* B, const T* weight_scales,
+                                            const T* biases, T* C, int64_t* total_rows_before_expert,
                                             int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
                                             cudaStream_t stream) {
-    static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
-    static constexpr bool only_simt_configs = std::is_same<T, float>::value;
-    std::vector<CutlassGemmConfig> candidate_configs = get_candidate_configs(sm_, is_weight_only, only_simt_configs);
-    std::vector<int> occupancies(candidate_configs.size());
-
-    for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
-        dispatch_to_arch<EpilogueTag>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
-                                      gemm_k, num_experts, candidate_configs[ii], stream, &occupancies[ii]);
-    }
-
-    static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs.
-    static constexpr int split_k_limit = 1;   // MoE GEMM does not support split-k.
-    CutlassGemmConfig chosen_config =
-        estimate_best_config_from_occupancies(candidate_configs, occupancies, total_rows, gemm_n, gemm_k, num_experts,
-                                              split_k_limit, workspace_bytes, multi_processor_count_, is_weight_only);
-
-    dispatch_to_arch<EpilogueTag>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
-                                  num_experts, chosen_config, stream);
+  static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
+  static constexpr bool only_simt_configs = std::is_same<T, float>::value;
+  std::vector<CutlassGemmConfig> candidate_configs = get_candidate_configs(sm_, is_weight_only, only_simt_configs);
+  std::vector<int> occupancies(candidate_configs.size());
+
+  for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
+    dispatch_to_arch<EpilogueTag>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
+                                  gemm_k, num_experts, candidate_configs[ii], stream, &occupancies[ii]);
+  }
+
+  static constexpr int workspace_bytes = 0;  // No workspace for MoE GEMMs.
+  static constexpr int split_k_limit = 1;    // MoE GEMM does not support split-k.
+  CutlassGemmConfig chosen_config =
+      estimate_best_config_from_occupancies(candidate_configs, occupancies, total_rows, gemm_n, gemm_k, num_experts,
+                                            split_k_limit, workspace_bytes, multi_processor_count_, is_weight_only);
+
+  dispatch_to_arch<EpilogueTag>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
+                                num_experts, chosen_config, stream);
 }
 
 template <typename T, typename WeightType>
-void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(const T *A, const WeightType *B, const T *weight_scales,
-                                                     const T *biases, T *C, int64_t *total_rows_before_expert,
+void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales,
+                                                     const T* biases, T* C, int64_t* total_rows_before_expert,
                                                      int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
                                                      int num_experts, ActivationType activation_type,
                                                      cudaStream_t stream) {
-    switch (activation_type) {
+  switch (activation_type) {
     case ActivationType::Relu:
-        run_gemm<EpilogueOpDefaultReLU>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
-                                        gemm_k, num_experts, stream);
-        break;
+      run_gemm<EpilogueOpDefaultReLU>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
                                       gemm_k, num_experts, stream);
+      break;
     case ActivationType::Gelu:
-        run_gemm<EpilogueOpDefaultFtGelu>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
-                                          gemm_k, num_experts, stream);
-        break;
+      run_gemm<EpilogueOpDefaultFtGelu>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
+                                        gemm_k, num_experts, stream);
+      break;
     case ActivationType::Silu:
-        run_gemm<EpilogueOpDefaultSilu>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
-                                        gemm_k, num_experts, stream);
-        break;
+      run_gemm<EpilogueOpDefaultSilu>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
+                                      gemm_k, num_experts, stream);
+      break;
     case ActivationType::Identity:
-        run_gemm<EpilogueOpDefault>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
-                                    gemm_k, num_experts, stream);
-        break;
+      run_gemm<EpilogueOpDefault>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n,
+                                  gemm_k, num_experts, stream);
+      break;
     case ActivationType::InvalidType:
-        ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM");
-        break;
+      ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM");
+      break;
     default: {
-        ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM");
-        }
+      ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM");
     }
+  }
 }
 
 // template <typename T, typename WeightType>
@@ -565,16 +568,16 @@ void MoeGemmRunner<T, WeightType>::moe_gemm_bias_act(const T *A, const WeightTyp
 // }
 
 template <typename T, typename WeightType>
-void MoeGemmRunner<T, WeightType>::moe_gemm(const T *A, const WeightType *B, const T *weight_scales, const T *biases,
-                                            T *C, int64_t *total_rows_before_expert, int64_t total_rows, int64_t gemm_n,
+void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases,
+                                            T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n,
                                             int64_t gemm_k, int num_experts, cudaStream_t stream) {
-    // if (biases != nullptr) {
-    run_gemm<EpilogueOpDefault>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
-                                num_experts, stream);
-    // } else {
-    //   run_gemm<EpilogueOpDefault>(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n,
-    //                               gemm_k, num_experts, stream);
-    // }
+  // if (biases != nullptr) {
+  run_gemm<EpilogueOpDefault>(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
+                              num_experts, stream);
+  // } else {
+  //   run_gemm<EpilogueOpDefault>(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n,
+  //                               gemm_k, num_experts, stream);
+  // }
 }
 
-} // namespace ort_fastertransformer
+}  // namespace ort_fastertransformer
diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
index 90b7da255081a..50292f186df15 100644
--- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
+++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
@@ -226,9 +226,9 @@ def __init__(self, config, batch_size, sequence_length):
         w2_list = []
         w3_list = []
         for i in range(self.num_experts):
-            w1_list.append(self.experts[i].w1.weight.transpose(0, 1))
-            w2_list.append(self.experts[i].w2.weight.transpose(0, 1))
-            w3_list.append(self.experts[i].w3.weight.transpose(0, 1))
+            w1_list.append(self.experts[i].w1.weight)
+            w2_list.append(self.experts[i].w2.weight)
+            w3_list.append(self.experts[i].w3.weight)
 
         self.moe_experts_weight1 = torch.stack(w1_list, dim=0)
         self.moe_experts_weight2 = torch.stack(w2_list, dim=0)
diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py
index dbf6ee7dabb0e..aa480a1af4587 100644
--- a/onnxruntime/test/python/transformers/test_parity_moe.py
+++ b/onnxruntime/test/python/transformers/test_parity_moe.py
@@ -249,9 +249,9 @@ def __init__(
             num_experts,
             in_features,
             hidden_features,
-            self.moe_experts.weight1,
+            self.moe_experts.weight1.transpose(1, 2),
             self.moe_experts.bias1,
-            self.moe_experts.weight2,
+            self.moe_experts.weight2.transpose(1, 2),
             self.moe_experts.bias2,
         )
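
The two Python hunks above relocate a weight transpose: test_parity_mixtral_moe.py stops transposing each expert's weight before stacking, while test_parity_moe.py now transposes its stacked weight tensors at the point where they are passed along. For illustration, the following standalone PyTorch snippet (sizes and variable names are invented for the example, not taken from the tests) shows that transposing each per-expert weight before stacking and transposing the stacked tensor afterwards produce the same tensor:

import torch

# Illustrative sizes only; the real tests derive these from the model config.
num_experts, in_features, out_features = 4, 8, 16
experts = [torch.nn.Linear(in_features, out_features, bias=False) for _ in range(num_experts)]

# Old test_parity_mixtral_moe.py style: transpose each (out, in) weight, then stack -> (E, in, out).
stacked_pre_transposed = torch.stack([e.weight.transpose(0, 1) for e in experts], dim=0)

# New style: stack the raw (out, in) weights -> (E, out, in), then transpose the last two
# dimensions only where a consumer needs (E, in, out), as test_parity_moe.py now does.
stacked_raw = torch.stack([e.weight for e in experts], dim=0)

assert torch.equal(stacked_raw.transpose(1, 2), stacked_pre_transposed)

Which of the two layouts the MoE kernel ultimately expects is defined by the CUDA implementation, not by this sketch; the patch only moves the point in the tests at which the transpose is applied.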