diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 6ccf063c71290..66009e5b3e229 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -51,6 +51,7 @@ set(contrib_ops_excluded_files
   "math/gemm_float8.cc"
   "math/gemm_float8.cu"
   "math/gemm_float8.h"
+  "moe/*"
   "quantization/attention_quantization.cc"
   "quantization/attention_quantization.h"
   "quantization/attention_quantization_impl.cu"
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h
index 793c26114b8aa..aba235a62f94f 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/common.h
@@ -25,56 +25,50 @@
 #include "stdio.h"
 #include
 
-namespace fastertransformer
-{
+namespace fastertransformer {
 
-static const char *_cudaGetErrorEnum(cublasStatus_t error)
-{
-    switch (error)
-    {
-        case CUBLAS_STATUS_SUCCESS:
-            return "CUBLAS_STATUS_SUCCESS";
+static const char* _cudaGetErrorEnum(cublasStatus_t error) {
+  switch (error) {
+    case CUBLAS_STATUS_SUCCESS:
+      return "CUBLAS_STATUS_SUCCESS";
 
-        case CUBLAS_STATUS_NOT_INITIALIZED:
-            return "CUBLAS_STATUS_NOT_INITIALIZED";
+    case CUBLAS_STATUS_NOT_INITIALIZED:
+      return "CUBLAS_STATUS_NOT_INITIALIZED";
 
-        case CUBLAS_STATUS_ALLOC_FAILED:
-            return "CUBLAS_STATUS_ALLOC_FAILED";
+    case CUBLAS_STATUS_ALLOC_FAILED:
+      return "CUBLAS_STATUS_ALLOC_FAILED";
 
-        case CUBLAS_STATUS_INVALID_VALUE:
-            return "CUBLAS_STATUS_INVALID_VALUE";
+    case CUBLAS_STATUS_INVALID_VALUE:
+      return "CUBLAS_STATUS_INVALID_VALUE";
 
-        case CUBLAS_STATUS_ARCH_MISMATCH:
-            return "CUBLAS_STATUS_ARCH_MISMATCH";
+    case CUBLAS_STATUS_ARCH_MISMATCH:
+      return "CUBLAS_STATUS_ARCH_MISMATCH";
 
-        case CUBLAS_STATUS_MAPPING_ERROR:
-            return "CUBLAS_STATUS_MAPPING_ERROR";
+    case CUBLAS_STATUS_MAPPING_ERROR:
+      return "CUBLAS_STATUS_MAPPING_ERROR";
 
-        case CUBLAS_STATUS_EXECUTION_FAILED:
-            return "CUBLAS_STATUS_EXECUTION_FAILED";
+    case CUBLAS_STATUS_EXECUTION_FAILED:
+      return "CUBLAS_STATUS_EXECUTION_FAILED";
 
-        case CUBLAS_STATUS_INTERNAL_ERROR:
-            return "CUBLAS_STATUS_INTERNAL_ERROR";
+    case CUBLAS_STATUS_INTERNAL_ERROR:
+      return "CUBLAS_STATUS_INTERNAL_ERROR";
 
-        case CUBLAS_STATUS_NOT_SUPPORTED:
-            return "CUBLAS_STATUS_NOT_SUPPORTED";
+    case CUBLAS_STATUS_NOT_SUPPORTED:
+      return "CUBLAS_STATUS_NOT_SUPPORTED";
 
-        case CUBLAS_STATUS_LICENSE_ERROR:
-            return "CUBLAS_STATUS_LICENSE_ERROR";
+    case CUBLAS_STATUS_LICENSE_ERROR:
+      return "CUBLAS_STATUS_LICENSE_ERROR";
   }
   return "";
 }
 
-static const char *_cudaGetErrorEnum(cudaError_t error)
-{
+static const char* _cudaGetErrorEnum(cudaError_t error) {
   return cudaGetErrorString(error);
 }
 
 template <typename T>
-void check(T result, char const *const func, const char *const file, int const line)
-{
-    if (result)
-    {
+void check(T result, char const* const func, const char* const file, int const line) {
+  if (result) {
     throw std::runtime_error(std::string("[FT][ERROR] CUDA runtime error: ") + (_cudaGetErrorEnum(result)) + " " +
                              file + ":" + std::to_string(line) + " \n");
@@ -83,4 +77,4 @@ void check(T result, char const *const func, const char *const file, int const l
 
 #define check_cuda_error(val) fastertransformer::check((val), #val, __FILE__, __LINE__)
 
-} // namespace fastertransformer
+}  // namespace fastertransformer
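Usage note on common.h: check_cuda_error wraps any CUDA or cuBLAS status and throws std::runtime_error on failure, with the decoded enum name plus file:line. A minimal sketch of a caller, assuming common.h is on the include path (the function and argument names here are invented for illustration):

#include <cuda_runtime.h>
// #include "contrib_ops/cuda/moe/ft_moe/common.h"  // assumed path; provides check_cuda_error

void copy_device_buffer(float* dst, const float* src, size_t n, cudaStream_t stream) {
  // Any non-zero cudaError_t makes check() throw, e.g.
  // "[FT][ERROR] CUDA runtime error: cudaErrorInvalidValue file.cc:42".
  check_cuda_error(cudaMemcpyAsync(dst, src, n * sizeof(float), cudaMemcpyDeviceToDevice, stream));
  check_cuda_error(cudaStreamSynchronize(stream));
}
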
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
index 75bad78929fb4..90f7e78eaf86a 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
@@ -22,30 +22,28 @@
 namespace fastertransformer {
 
-template<typename GemmKernel>
-inline int compute_occupancy_for_kernel()
-{
-
-    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
-
-    if (smem_size > (48 << 10)) {
-        cudaError_t status =
-            cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
-        if (status == cudaError::cudaErrorInvalidValue) {
-            // Clear the error bit since we can ignore this.
-            // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
-            // occupancy of 0. This will cause the heuristic to ignore this configuration.
-            status = cudaGetLastError();
-            return 0;
-        }
-        check_cuda_error(status);
+template <typename GemmKernel>
+inline int compute_occupancy_for_kernel() {
+  int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
+
+  if (smem_size > (48 << 10)) {
+    cudaError_t status =
+        cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+    if (status == cudaError::cudaErrorInvalidValue) {
+      // Clear the error bit since we can ignore this.
+      // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
+      // occupancy of 0. This will cause the heuristic to ignore this configuration.
+      status = cudaGetLastError();
+      return 0;
     }
+    check_cuda_error(status);
+  }
 
-    int max_active_blocks = -1;
-    check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-        &max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
+  int max_active_blocks = -1;
+  check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+      &max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
 
-    return max_active_blocks;
+  return max_active_blocks;
 }
 
 }  // namespace fastertransformer
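The helper above is the standard two-step CUDA occupancy pattern: opt in to >48KB dynamic shared memory if needed, then query blocks-per-SM. The same steps for an ordinary kernel, as a standalone sketch (kernel and block size are illustrative only):

#include <cuda_runtime.h>

__global__ void saxpy(int n, float a, const float* x, float* y) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) y[i] = a * x[i] + y[i];
}

int occupancy_for_saxpy() {
  int max_active_blocks = -1;
  // Blocks of 256 threads, no dynamic shared memory: how many fit per SM?
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, saxpy, 256, 0);
  return max_active_blocks;  // 0 means this launch configuration cannot run at all
}
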
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
index a4eb77ab2c557..71a88fa191202 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
@@ -23,185 +23,173 @@
 namespace fastertransformer {
 
 struct TileShape {
-    int m;
-    int n;
+  int m;
+  int n;
 };
 
-TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
-{
-    switch (tile_config) {
-        case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
-            return TileShape{32, 128};
-        case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
-        case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
-            return TileShape{64, 128};
-        case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
-        case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
-        case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
-            return TileShape{128, 128};
-        default:
-            throw std::runtime_error("[FT Error][get_grid_shape_for_config] Invalid config");
-    }
+TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) {
+  switch (tile_config) {
+    case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
+      return TileShape{32, 128};
+    case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
+    case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
+      return TileShape{64, 128};
+    case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
+    case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
+    case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
+      return TileShape{128, 128};
+    default:
+      throw std::runtime_error("[FT Error][get_grid_shape_for_config] Invalid config");
+  }
 }
 
-bool is_valid_split_k_factor(const int64_t m,
-                             const int64_t n,
-                             const int64_t k,
+bool is_valid_split_k_factor(const int64_t m,
+                             const int64_t n,
+                             const int64_t k,
                              const TileShape tile_shape,
-                             const int split_k_factor,
-                             const size_t workspace_bytes,
-                             const bool is_weight_only)
-{
-
-    // All tile sizes have a k_tile of 64.
-    static constexpr int k_tile = 64;
-
-    // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k
-    if (is_weight_only) {
-        if ((k % k_tile) != 0) {
-            return false;
-        }
+                             const int split_k_factor,
+                             const size_t workspace_bytes,
+                             const bool is_weight_only) {
+  // All tile sizes have a k_tile of 64.
+  static constexpr int k_tile = 64;
+
+  // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k
+  if (is_weight_only) {
+    if ((k % k_tile) != 0) {
+      return false;
+    }
 
-        if ((k % split_k_factor) != 0) {
-            return false;
-        }
+    if ((k % split_k_factor) != 0) {
+      return false;
+    }
 
-        const int k_elements_per_split = k / split_k_factor;
-        if ((k_elements_per_split % k_tile) != 0) {
-            return false;
-        }
+    const int k_elements_per_split = k / split_k_factor;
+    if ((k_elements_per_split % k_tile) != 0) {
+      return false;
     }
+  }
 
-    // Check that the workspace has sufficient space for this split-k factor
-    const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
-    const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
-    const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
+  // Check that the workspace has sufficient space for this split-k factor
+  const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
+  const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
+  const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
 
-    if (required_ws_bytes > workspace_bytes) {
-        return false;
-    }
+  if (required_ws_bytes > workspace_bytes) {
+    return false;
+  }
 
-    return true;
+  return true;
 }
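To make the workspace check in is_valid_split_k_factor concrete, a worked example with made-up sizes: an m=512, n=1024 problem on a 64x128 tile with split_k_factor=3 needs one serial-reduction counter per output tile:

#include <cstdio>

int main() {
  const int m = 512, n = 1024, tile_m = 64, tile_n = 128, split_k_factor = 3;
  const int ctas_in_m_dim = (m + tile_m - 1) / tile_m;  // 8
  const int ctas_in_n_dim = (n + tile_n - 1) / tile_n;  // 8
  // One int flag per (m, n) tile when split_k_factor > 1, else no workspace.
  const int required_ws_bytes =
      split_k_factor == 1 ? 0 : int(sizeof(int)) * ctas_in_m_dim * ctas_in_n_dim;
  std::printf("workspace needed: %d bytes\n", required_ws_bytes);  // 8 * 8 * 4 = 256
  return 0;
}
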
 
-std::vector<CutlassTileConfig> get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only)
-{
-    std::vector<CutlassTileConfig> simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
-
-    std::vector<CutlassTileConfig> square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
-                                                  CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
-                                                  CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64};
-
-    std::vector<CutlassTileConfig> quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
-                                                   CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
-                                                   CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
-
-    const std::vector<CutlassTileConfig> allowed_configs = is_weight_only ? quant_B_configs : square_configs;
-    return simt_configs_only ? simt_configs : allowed_configs;
+std::vector<CutlassTileConfig> get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only) {
+  std::vector<CutlassTileConfig> simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
+
+  std::vector<CutlassTileConfig> square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
+                                                CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
+                                                CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64};
+
+  std::vector<CutlassTileConfig> quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
+                                                 CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
+                                                 CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
+
+  const std::vector<CutlassTileConfig> allowed_configs = is_weight_only ? quant_B_configs : square_configs;
+  return simt_configs_only ? simt_configs : allowed_configs;
 }
 
-std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only)
-{
-    std::vector<CutlassTileConfig> tiles = get_candidate_tiles(is_weight_only, simt_configs_only);
-
-    std::vector<CutlassGemmConfig> candidate_configs;
-    const int min_stages = 2;
-    const int max_stages = sm >= 80 ? 4 : 2;
-
-    for (const auto& tile_config : tiles) {
-        for (int stages = min_stages; stages <= max_stages; ++stages) {
-            CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
-            candidate_configs.push_back(config);
-        }
-    }
-
-    return candidate_configs;
+std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only) {
+  std::vector<CutlassTileConfig> tiles = get_candidate_tiles(is_weight_only, simt_configs_only);
+
+  std::vector<CutlassGemmConfig> candidate_configs;
+  const int min_stages = 2;
+  const int max_stages = sm >= 80 ? 4 : 2;
+
+  for (const auto& tile_config : tiles) {
+    for (int stages = min_stages; stages <= max_stages; ++stages) {
+      CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
+      candidate_configs.push_back(config);
+    }
+  }
+
+  return candidate_configs;
 }
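Sanity check on the enumeration above: with weight-only quantization off and tensor-core tiles, there are 3 tile shapes, and on sm >= 80 stages run 2..4, giving 9 candidates. A hedged driver sketch (the include path is an assumption):

#include <cstdio>
#include "contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h"  // assumed path

int main() {
  auto configs = fastertransformer::get_candidate_configs(/*sm=*/80,
                                                          /*is_weight_only=*/false,
                                                          /*simt_configs_only=*/false);
  // Below sm 80 stages are pinned to 2, so this would print 3 instead.
  std::printf("%zu candidate configs\n", configs.size());  // 3 tiles x 3 stage counts = 9
  return 0;
}
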
 
 CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
-                                                        const std::vector<int>& occupancies,
-                                                        const int64_t m,
-                                                        const int64_t n,
-                                                        const int64_t k,
-                                                        const int64_t ,
-                                                        const int split_k_limit,
-                                                        const size_t workspace_bytes,
-                                                        const int multi_processor_count,
-                                                        const int is_weight_only)
-{
-
-    if (occupancies.size() != candidate_configs.size()) {
-        throw std::runtime_error("[FT Error][estimate_best_config_from_occupancies] occpancies and "
-                                 "candidate configs vectors must have equal length.");
-    }
-
-    CutlassGemmConfig best_config;
-    // Score will be [0, 1]. The objective is to minimize this score.
-    // It represents the fraction of SM resources unused in the last wave.
-    float config_score = 1.0f;
-    int config_waves = INT_MAX;
-    int current_m_tile = 0;
-
-    const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
-    for (int ii = 0; ii < candidate_configs.size(); ++ii) {
-        CutlassGemmConfig candidate_config = candidate_configs[ii];
-        TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
-        int occupancy = occupancies[ii];
-
-        if (occupancy == 0) {
-            continue;
-        }
-
-        // Keep small tile sizes when possible.
-        if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile
-            && current_m_tile < tile_shape.m) {
-            continue;
-        }
-
-        const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
-        const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
-
-        for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) {
-            if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) {
-                const int ctas_per_wave = occupancy * multi_processor_count;
-                const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
-
-                const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
-                const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
-                const float current_score = float(num_waves_total) - num_waves_fractional;
-
-                const float score_slack = 0.1f;
-                if (current_score < config_score
-                    || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) {
-                    config_score = current_score;
-                    config_waves = num_waves_total;
-                    SplitKStyle split_style =
-                        split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
-                    best_config = CutlassGemmConfig{
-                        candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
-                    current_m_tile = tile_shape.m;
-                }
-                else if (current_score == config_score
-                         && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor
-                             || current_m_tile < tile_shape.m)) {
-                    // Prefer deeper pipeline or smaller split-k
-                    SplitKStyle split_style =
-                        split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
-                    best_config = CutlassGemmConfig{
-                        candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
-                    current_m_tile = tile_shape.m;
-                    config_waves = num_waves_total;
-                }
-            }
-        }
-    }
-
-    if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) {
-        throw std::runtime_error("[FT Error] Heurisitc failed to find a valid config.");
-    }
-
-    return best_config;
+                                                        const std::vector<int>& occupancies,
+                                                        const int64_t m,
+                                                        const int64_t n,
+                                                        const int64_t k,
+                                                        const int64_t,
+                                                        const int split_k_limit,
+                                                        const size_t workspace_bytes,
+                                                        const int multi_processor_count,
+                                                        const int is_weight_only) {
+  if (occupancies.size() != candidate_configs.size()) {
+    throw std::runtime_error(
+        "[FT Error][estimate_best_config_from_occupancies] occupancies and "
+        "candidate configs vectors must have equal length.");
+  }
+
+  CutlassGemmConfig best_config;
+  // Score will be [0, 1]. The objective is to minimize this score.
+  // It represents the fraction of SM resources unused in the last wave.
+  float config_score = 1.0f;
+  int config_waves = INT_MAX;
+  int current_m_tile = 0;
+
+  const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
+  for (int ii = 0; ii < candidate_configs.size(); ++ii) {
+    CutlassGemmConfig candidate_config = candidate_configs[ii];
+    TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
+    int occupancy = occupancies[ii];
+
+    if (occupancy == 0) {
+      continue;
+    }
+
+    // Keep small tile sizes when possible.
+    if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile && current_m_tile < tile_shape.m) {
+      continue;
+    }
+
+    const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
+    const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
+
+    for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) {
+      if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) {
+        const int ctas_per_wave = occupancy * multi_processor_count;
+        const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
+
+        const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
+        const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
+        const float current_score = float(num_waves_total) - num_waves_fractional;
+
+        const float score_slack = 0.1f;
+        if (current_score < config_score || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) {
+          config_score = current_score;
+          config_waves = num_waves_total;
+          SplitKStyle split_style =
+              split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
+          best_config = CutlassGemmConfig{
+              candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
+          current_m_tile = tile_shape.m;
+        } else if (current_score == config_score && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor || current_m_tile < tile_shape.m)) {
+          // Prefer deeper pipeline or smaller split-k
+          SplitKStyle split_style =
+              split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
+          best_config = CutlassGemmConfig{
+              candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
+          current_m_tile = tile_shape.m;
+          config_waves = num_waves_total;
+        }
+      }
+    }
+  }
+
+  if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) {
+    throw std::runtime_error("[FT Error] Heuristic failed to find a valid config.");
+  }
+
+  return best_config;
 }
 
 }  // namespace fastertransformer
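The score above is the wasted fraction of the tail wave. A worked example with made-up numbers: 80 SMs at occupancy 2 gives 160 CTAs per wave; a problem needing 200 CTAs takes ceil(200/160) = 2 waves, of which 200/160 = 1.25 are full, so the score is 0.75:

#include <cstdio>

int main() {
  const int multi_processor_count = 80, occupancy = 2;
  const int ctas_for_problem = 200;  // ctas_in_m_dim * ctas_in_n_dim * split_k_factor
  const int ctas_per_wave = occupancy * multi_processor_count;                         // 160
  const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;  // 2
  const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);          // 1.25
  std::printf("score = %.2f\n", float(num_waves_total) - num_waves_fractional);        // 0.75
  return 0;
}
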
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h
index fc566eeea23ad..7c3289495502f 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h
@@ -22,21 +22,21 @@
 #include
 #include
 
-//#include "src/fastertransformer/utils/cuda_utils.h"
+// #include "src/fastertransformer/utils/cuda_utils.h"
 
 namespace fastertransformer {
 
 std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only);
 
 CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
-                                                        const std::vector<int>& occupancies,
-                                                        const int64_t m,
-                                                        const int64_t n,
-                                                        const int64_t k,
-                                                        const int64_t num_experts,
-                                                        const int split_k_limit,
-                                                        const size_t workspace_bytes,
-                                                        const int multi_processor_count,
-                                                        const int is_weight_only);
+                                                        const std::vector<int>& occupancies,
+                                                        const int64_t m,
+                                                        const int64_t n,
+                                                        const int64_t k,
+                                                        const int64_t num_experts,
+                                                        const int split_k_limit,
+                                                        const size_t workspace_bytes,
+                                                        const int multi_processor_count,
+                                                        const int is_weight_only);
 
 }  // namespace fastertransformer
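Taken together, the intended flow of this header is: enumerate candidates, measure occupancy per candidate, then pick by wave efficiency. A hedged caller sketch (the occupancy values are faked here; the real GEMM runner fills them via compute_occupancy_for_kernel, and the numeric arguments are placeholders):

#include <vector>
// #include "cutlass_heuristic.h"  // assumed include

fastertransformer::CutlassGemmConfig pick_config(int sm, int sm_count,
                                                 int64_t m, int64_t n, int64_t k) {
  using namespace fastertransformer;
  std::vector<CutlassGemmConfig> candidates = get_candidate_configs(sm, false, false);
  // One occupancy entry per candidate; 0 disqualifies a candidate.
  std::vector<int> occupancies(candidates.size(), 2);  // stand-in values
  return estimate_best_config_from_occupancies(candidates, occupancies, m, n, k,
                                               /*num_experts=*/8, /*split_k_limit=*/7,
                                               /*workspace_bytes=*/4 << 20, sm_count,
                                               /*is_weight_only=*/0);
}
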
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
index 5ca59c330524c..e43bb4de78c26 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
@@ -26,46 +26,39 @@
 namespace cutlass {
 namespace epilogue {
 namespace thread {
 
-__forceinline__ __device__ float copysignf_pos(float a, float b)
-{
-    float r;
-    r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
-    return r;
+__forceinline__ __device__ float copysignf_pos(float a, float b) {
+  float r;
+  r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
+  return r;
 }
 
-__forceinline__ __device__ float tanh_opt(float x)
-{
+__forceinline__ __device__ float tanh_opt(float x) {
 #if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
-    const float exp_val = -1.f * fabs(2 * x);
-    return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
+  const float exp_val = -1.f * fabs(2 * x);
+  return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
 #else
-    return fast_tanh(x);
+  return fast_tanh(x);
 #endif
 }
-
-template<>
+template <>
 struct GELU_taylor<float> {
-    static const bool kIsHeavy = true;
-    CUTLASS_DEVICE
-    float operator()(float const& z) const
-    {
-
-        float k0 = float(0.7978845608028654);
-        float k1 = float(0.044715);
-
-        return float(
-            cutlass::constants::half<float>() * z
-            * (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
-    }
-
-    using Params = LinearCombinationGenericParams<float>;
-
-    CUTLASS_DEVICE
-    float operator()(float const& scalar, Params const& params_) const
-    {
-        return this->operator()(scalar);
-    }
+  static const bool kIsHeavy = true;
+  CUTLASS_DEVICE
+  float operator()(float const& z) const {
+    float k0 = float(0.7978845608028654);
+    float k1 = float(0.044715);
+
+    return float(
+        cutlass::constants::half<float>() * z * (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
+  }
+
+  using Params = LinearCombinationGenericParams<float>;
+
+  CUTLASS_DEVICE
+  float operator()(float const& scalar, Params const& params_) const {
+    return this->operator()(scalar);
+  }
 };
 
 }  // namespace thread
@@ -84,56 +77,56 @@
 
 struct EpilogueOpBias {};
 
 struct EpilogueOpNoBias {};
 
-template
+template
 struct Epilogue {
 };
 
-template
+template
 struct Epilogue {
-    using Op = cutlass::epilogue::thread::LinearCombinationSilu;
+  using Op = cutlass::epilogue::thread::LinearCombinationSilu;
 };
 
-template
+template
 struct Epilogue {
-    using Op = cutlass::epilogue::thread::LinearCombinationRelu;
+  using Op = cutlass::epilogue::thread::LinearCombinationRelu;
 };
 
-template
+template
 struct Epilogue {
-    using Op = cutlass::epilogue::thread::LinearCombinationGeneric;
+  using Op = cutlass::epilogue::thread::LinearCombinationGeneric;
 };
 
-template
+template
 struct Epilogue {
-    using Op = cutlass::epilogue::thread::LinearCombination;
+  using Op = cutlass::epilogue::thread::LinearCombination;
 };
 
-template
+template
 struct Epilogue {
-    using Op = cutlass::epilogue::thread::LinearCombination;
+  using Op = cutlass::epilogue::thread::LinearCombination;
 };
 
 }  // namespace fastertransformer
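The GELU_taylor<float> specialization above is the usual tanh approximation of GELU, with k0 = sqrt(2/pi) and k1 = 0.044715. A standalone host-side check of the same formula against the exact erf form:

#include <cmath>
#include <cstdio>

float gelu_tanh(float z) {
  const float k0 = 0.7978845608028654f;  // sqrt(2 / pi)
  const float k1 = 0.044715f;
  return 0.5f * z * (1.0f + std::tanh(k0 * z * (1.0f + k1 * z * z)));
}

float gelu_exact(float z) {
  return 0.5f * z * (1.0f + std::erf(z / std::sqrt(2.0f)));
}

int main() {
  for (float z : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f})
    std::printf("z=%5.2f  tanh=%.6f  exact=%.6f\n", z, gelu_tanh(z), gelu_exact(z));
  return 0;
}
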
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h
index 2ac736ba06639..ec752756d7a9f 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h
@@ -20,39 +20,39 @@
 namespace fastertransformer {
 // Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
 // in the kernel layout details when doing weight only quantization.
 enum class CutlassTileConfig {
-    // Signals that we should run heuristics do choose a config
-    Undefined,
+  // Signals that we should run heuristics to choose a config
+  Undefined,
 
-    // Signals that we should run heuristics do choose a config
-    ChooseWithHeuristic,
+  // Signals that we should run heuristics to choose a config
+  ChooseWithHeuristic,
 
-    // SiMT config
-    CtaShape128x128x8_WarpShape64x64x8,
+  // SiMT config
+  CtaShape128x128x8_WarpShape64x64x8,
 
-    // TensorCore configs CTA_N = 128, CTA_K = 64
-    // Warp configs for M=32
-    CtaShape32x128x64_WarpShape32x32x64,
+  // TensorCore configs CTA_N = 128, CTA_K = 64
+  // Warp configs for M=32
+  CtaShape32x128x64_WarpShape32x32x64,
 
-    // Warp configs for M=64
-    CtaShape64x128x64_WarpShape32x64x64,
-    CtaShape64x128x64_WarpShape64x32x64,
+  // Warp configs for M=64
+  CtaShape64x128x64_WarpShape32x64x64,
+  CtaShape64x128x64_WarpShape64x32x64,
 
-    // Warp configs for M=128
-    CtaShape128x128x64_WarpShape64x32x64,
-    CtaShape128x128x64_WarpShape128x32x64
+  // Warp configs for M=128
+  CtaShape128x128x64_WarpShape64x32x64,
+  CtaShape128x128x64_WarpShape128x32x64
 };
 
 enum class SplitKStyle {
-    NO_SPLIT_K,
-    SPLIT_K_SERIAL,
-    // SPLIT_K_PARALLEL // Not supported yet
+  NO_SPLIT_K,
+  SPLIT_K_SERIAL,
+  // SPLIT_K_PARALLEL // Not supported yet
 };
 
 struct CutlassGemmConfig {
-    CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
-    SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
-    int split_k_factor = -1;
-    int stages = -1;
+  CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
+  SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
+  int split_k_factor = -1;
+  int stages = -1;
 };
 
 }  // namespace fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h
index 495a3c327d565..c003d5a8ea850 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h
@@ -49,34 +49,30 @@
 namespace gemm {
 namespace kernel {
 
 /// Visitor class to abstract away the algorithm for iterating over tiles
-template<typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount, bool Transposed = false>
-struct GemmMoeProblemVisitor:
-    public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<ThreadblockShape>,
-                             ThreadblockShape,
-                             GroupScheduleMode_,
-                             PrefetchTileCount,
-                             ThreadCount> {
-
-    static bool const kTransposed = Transposed;
-
-    using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape>;
-    using Base =
-        MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
-    using Params = typename Base::Params;
-    using SharedStorage = typename Base::SharedStorage;
-
-    //
-    // Methods
-    //
-    CUTLASS_DEVICE
-    GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx):
-        Base(params_, shared_storage_, block_idx)
-    {
-    }
+template <typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount, bool Transposed = false>
+struct GemmMoeProblemVisitor : public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<ThreadblockShape>,
+                                                        ThreadblockShape,
+                                                        GroupScheduleMode_,
+                                                        PrefetchTileCount,
+                                                        ThreadCount> {
+  static bool const kTransposed = Transposed;
+
+  using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<ThreadblockShape>;
+  using Base =
+      MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
+  using Params = typename Base::Params;
+  using SharedStorage = typename Base::SharedStorage;
+
+  //
+  // Methods
+  //
+  CUTLASS_DEVICE
+  GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) : Base(params_, shared_storage_, block_idx) {
+  }
 };
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
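The visitor walks per-expert GEMMs whose row counts come from a running prefix sum (total_rows_before_expert, surfaced as last_row_for_problem inside the kernel). A host-side sketch of that mapping with invented numbers; note an expert can receive zero rows:

#include <cstdint>
#include <cstdio>

int main() {
  // Rows routed to each of 4 experts, stored as a running total.
  const int64_t total_rows_before_expert[4] = {128, 128, 320, 512};
  const int64_t gemm_n = 1024, gemm_k = 4096;
  for (int expert = 0; expert < 4; ++expert) {
    const int64_t start = expert == 0 ? 0 : total_rows_before_expert[expert - 1];
    const int64_t rows = total_rows_before_expert[expert] - start;
    // Each expert solves an independent (rows x gemm_n x gemm_k) GEMM.
    std::printf("expert %d: m=%lld starting at row %lld, n=%lld, k=%lld\n",
                expert, (long long)rows, (long long)start,
                (long long)gemm_n, (long long)gemm_k);
  }
  return 0;
}
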
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h
index 6b58a424d2cd1..ff5e4e4c8b62c 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h
@@ -21,47 +21,47 @@
 namespace cutlass {
 namespace gemm {
 namespace kernel {
 
-template<typename TypeB, typename Arch, typename Enable = void>
+template <typename TypeB, typename Arch, typename Enable = void>
 struct LayoutDetailsB {
 };
 
 // Volta specializations. Volta will dequantize before STS, so we need a different operator
-template<typename TypeB>
+template <typename TypeB>
 struct LayoutDetailsB<TypeB, arch::Sm70> {
-    static constexpr int ThreadblockK = 64;
-    using Layout = layout::RowMajor;
-    static constexpr int ElementsPerAccess = 8;
-    using Operator = cutlass::arch::OpMultiplyAdd;
+  static constexpr int ThreadblockK = 64;
+  using Layout = layout::RowMajor;
+  static constexpr int ElementsPerAccess = 8;
+  using Operator = cutlass::arch::OpMultiplyAdd;
 };
 
 // Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
 // TODO - Switch this to column major for weights since gemms should be more performant.
-template<typename TypeB, typename Arch>
+template <typename TypeB, typename Arch>
 struct LayoutDetailsB<TypeB, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
-    static constexpr int ThreadblockK = 64;
-    using Layout = layout::RowMajor;
-    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeB>::value;
-    using Operator = cutlass::arch::OpMultiplyAdd;
+  static constexpr int ThreadblockK = 64;
+  using Layout = layout::RowMajor;
+  static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeB>::value;
+  using Operator = cutlass::arch::OpMultiplyAdd;
 };
 
-template<typename TypeA, typename TypeB, typename Arch, typename Enable = void>
+template <typename TypeA, typename TypeB, typename Arch, typename Enable = void>
 struct MixedGemmArchTraits {
 };
 
-template<typename Arch>
+template <typename Arch>
 struct MixedGemmArchTraits<float, float, Arch> {
-    static constexpr int Stages = 2;
-    using OperatorClass = cutlass::arch::OpClassSimt;
-    using AccType = float;
-    using LayoutB = cutlass::layout::RowMajor;
-
-    static constexpr int ElementsPerAccessA = 1;
-    static constexpr int ElementsPerAccessB = 1;
-    static constexpr int ElementsPerAccessC = 1;
-    static constexpr int ThreadblockK = 8;
-    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-
-    using Operator = cutlass::arch::OpMultiplyAdd;
+  static constexpr int Stages = 2;
+  using OperatorClass = cutlass::arch::OpClassSimt;
+  using AccType = float;
+  using LayoutB = cutlass::layout::RowMajor;
+
+  static constexpr int ElementsPerAccessA = 1;
+  static constexpr int ElementsPerAccessB = 1;
+  static constexpr int ElementsPerAccessC = 1;
+  static constexpr int ThreadblockK = 8;
+  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
+
+  using Operator = cutlass::arch::OpMultiplyAdd;
 };
 
 // ========================= Volta Traits ===========================
@@ -69,83 +69,80 @@ struct MixedGemmArchTraits {
 // This will instantiate any HMMA tensorcore kernels for Volta.
 // Note that volta does not have native bfloat support so weights and activations will be casted to fp16
 // and compute will happen in fp16 then will be converted for bf16 output.
-template<typename TypeA, typename TypeB>
+template <typename TypeA, typename TypeB>
 struct MixedGemmArchTraits<
     TypeA,
     TypeB,
     cutlass::arch::Sm70,
-    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
-                                          || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
-private:
-    using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm70>;
+    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
+ private:
+  using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm70>;
 
-public:
-    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
+ public:
+  static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
 
-    using OperatorClass = cutlass::arch::OpClassTensorOp;
-    using AccType = float;
-    using LayoutB = typename LayoutDetails::Layout;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+  using AccType = float;
+  using LayoutB = typename LayoutDetails::Layout;
 
-    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
-    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
-    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
-    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
+  static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
+  static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
+  static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
+  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
 
-    using Operator = typename LayoutDetails::Operator;
+  using Operator = typename LayoutDetails::Operator;
 };
 
 // ======================= Turing Traits ==============================
 // Note that turing does not have native bfloat support so weights and activations will be casted to fp16
 // and compute will happen in fp16 then will be converted for bf16 output.
-template<typename TypeA, typename TypeB>
+template <typename TypeA, typename TypeB>
 struct MixedGemmArchTraits<
     TypeA,
     TypeB,
    cutlass::arch::Sm75,
-    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
-                                          || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
-private:
-    using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm75>;
+    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
+ private:
+  using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm75>;
 
-public:
-    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
+ public:
+  static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
 
-    using OperatorClass = cutlass::arch::OpClassTensorOp;
-    using AccType = float;
-    using LayoutB = typename LayoutDetails::Layout;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+  using AccType = float;
+  using LayoutB = typename LayoutDetails::Layout;
 
-    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
-    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
-    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
-    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+  static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
+  static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
+  static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
 
-    using Operator = typename LayoutDetails::Operator;
+  using Operator = typename LayoutDetails::Operator;
 };
 
 // ======================= Ampere Traits ==============================
-template<typename TypeA, typename TypeB>
+template <typename TypeA, typename TypeB>
 struct MixedGemmArchTraits<
     TypeA,
     TypeB,
     cutlass::arch::Sm80,
-    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
-                                          || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
-private:
-    using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm80>;
+    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
+ private:
+  using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm80>;
 
-public:
-    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
+ public:
+  static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
 
-    using OperatorClass = cutlass::arch::OpClassTensorOp;
-    using AccType = float;
-    using LayoutB = typename LayoutDetails::Layout;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+  using AccType = float;
+  using LayoutB = typename LayoutDetails::Layout;
 
-    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
-    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
-    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
-    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+  static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
+  static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
+  static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
 
-    using Operator = typename LayoutDetails::Operator;
+  using Operator = typename LayoutDetails::Operator;
 };
 
 }  // namespace kernel
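The next file leans on the same detection idiom used throughout these headers: a void_t-based partial specialization flips a trait whenever a nested type exists. A minimal standalone illustration (the type names are invented for the demo):

#include <type_traits>

template <typename...>
using void_t = void;

// Primary template: assume no scale iterator.
template <typename T, typename = void>
struct has_scale_iterator : std::false_type {};

// Chosen instead whenever T::IteratorScale names a type.
template <typename T>
struct has_scale_iterator<T, void_t<typename T::IteratorScale>> : std::true_type {};

struct PlainMma {};
struct DequantMma { using IteratorScale = int; };

static_assert(!has_scale_iterator<PlainMma>::value, "no nested type");
static_assert(has_scale_iterator<DequantMma>::value, "nested type detected");
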
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h
index 22a158e04ddc2..f2f9697c9c4b5 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h
@@ -51,476 +51,453 @@ namespace kernel {
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 // This section exists so that we can use the same kernel code for regular gemm and dequantizing gemms.
 // It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global.
-template<typename...>
-using void_t = void;
+template <typename...>
+using void_t = void;
 
-template<typename Mma, typename = void>
-struct use_dq_gemm: platform::false_type {
+template <typename Mma, typename = void>
+struct use_dq_gemm : platform::false_type {
 };
 
-template<typename Mma>
-struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>>: platform::true_type {
+template <typename Mma>
+struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>> : platform::true_type {
 };
 
 // SFINAE overload for dequantizing gemm
-template<typename Mma, typename ElementScale, typename platform::enable_if<use_dq_gemm<Mma>::value, bool>::type = true>
-CUTLASS_DEVICE static void run_mma(Mma mma,
-                                   int gemm_k_iterations,
-                                   typename Mma::FragmentC& accum,
-                                   typename Mma::IteratorA iterator_A,
-                                   typename Mma::IteratorB iterator_B,
-                                   typename Mma::FragmentC const& src_accum,
-                                   ElementScale* weight_scale_ptr,
-                                   MatrixCoord scale_extent,
-                                   const int thread_idx,
-                                   MatrixCoord tb_offset_scale)
-{
-
-    typename Mma::IteratorScale iterator_scale(
-        Mma::IteratorScale::Layout(scale_extent.column()), weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);
-
-    mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_scale, src_accum);
+template <typename Mma, typename ElementScale, typename platform::enable_if<use_dq_gemm<Mma>::value, bool>::type = true>
+CUTLASS_DEVICE static void run_mma(Mma mma,
+                                   int gemm_k_iterations,
+                                   typename Mma::FragmentC& accum,
+                                   typename Mma::IteratorA iterator_A,
+                                   typename Mma::IteratorB iterator_B,
+                                   typename Mma::FragmentC const& src_accum,
+                                   ElementScale* weight_scale_ptr,
+                                   MatrixCoord scale_extent,
+                                   const int thread_idx,
+                                   MatrixCoord tb_offset_scale) {
+  typename Mma::IteratorScale iterator_scale(
+      Mma::IteratorScale::Layout(scale_extent.column()), weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);
+
+  mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_scale, src_accum);
 }
 
 // SFINAE overload for normal gemm. This completely ignores the scale parameters
-template<typename Mma, typename ElementScale, typename platform::enable_if<!use_dq_gemm<Mma>::value, bool>::type = true>
-CUTLASS_DEVICE static void run_mma(Mma mma,
-                                   int gemm_k_iterations,
-                                   typename Mma::FragmentC& accum,
-                                   typename Mma::IteratorA iterator_A,
-                                   typename Mma::IteratorB iterator_B,
-                                   typename Mma::FragmentC const& src_accum,
-                                   ElementScale* weight_scale_ptr,
-                                   MatrixCoord scale_extent,
-                                   const int thread_idx,
-                                   MatrixCoord tb_offset_scale)
-{
-    mma(gemm_k_iterations, accum, iterator_A, iterator_B, src_accum);
+template <typename Mma, typename ElementScale, typename platform::enable_if<!use_dq_gemm<Mma>::value, bool>::type = true>
+CUTLASS_DEVICE static void run_mma(Mma mma,
+                                   int gemm_k_iterations,
+                                   typename Mma::FragmentC& accum,
+                                   typename Mma::IteratorA iterator_A,
+                                   typename Mma::IteratorB iterator_B,
+                                   typename Mma::FragmentC const& src_accum,
+                                   ElementScale* weight_scale_ptr,
+                                   MatrixCoord scale_extent,
+                                   const int thread_idx,
+                                   MatrixCoord tb_offset_scale) {
+  mma(gemm_k_iterations, accum, iterator_A, iterator_B, src_accum);
 }
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Mma_, typename Epilogue_, typename ThreadblockSwizzle_, typename KernelArch, GroupScheduleMode GroupScheduleMode_>
+template <typename Mma_, typename Epilogue_, typename ThreadblockSwizzle_, typename KernelArch, GroupScheduleMode GroupScheduleMode_>
 struct MoeFCGemm {
-public:
-    using Mma = Mma_;
-    using Epilogue = Epilogue_;
-    using EpilogueOutputOp = typename Epilogue::OutputOp;
-    using ThreadblockSwizzle = ThreadblockSwizzle_;
-    static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
-    static bool const kTransposed = false;
-
-    // Optional transpose
-    using MapArguments = kernel::detail::MapArguments;
-
-    // Public-facing type definitions related to operand element type, layout, and complex conjugate
-    // operation. Must interact with the 'kTransposed' notion.
-    static_assert(!kTransposed, "Transpose problem not supported");
-    using ElementA = typename MapArguments::ElementA;
-    using LayoutA = typename MapArguments::LayoutA;
-    using ElementB = typename MapArguments::ElementB;
-    using LayoutB = typename MapArguments::LayoutB;
-    using ElementC = typename Epilogue::OutputTileIterator::Element;
-    using LayoutC = typename MapArguments::LayoutC;
-    using ElementScale = ElementC;
-
-    static ComplexTransform const kTransformA = MapArguments::kTransformA;
-    static ComplexTransform const kTransformB = MapArguments::kTransformB;
-
-    // Type definitions about the mainloop.
-    using Operator = typename Mma::Operator;
-    using OperatorClass = typename Mma::Operator::OperatorClass;
-    using ThreadblockShape = typename Mma::Shape;
-    using WarpShape = typename Mma::Operator::Shape;
-    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
-    using ArchTag = typename Mma::ArchTag;
-
-    static int const kStages = Mma::kStages;
-    static int const kAlignmentA = MapArguments::kAlignmentA;
-    static int const kAlignmentB = MapArguments::kAlignmentB;
-    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
-
-    /// Warp count (concept: GemmShape)
-    using WarpCount = typename Mma::WarpCount;
-    static int const kThreadCount = 32 * WarpCount::kCount;
-
-    using ProblemVisitor =
-        GemmMoeProblemVisitor;
-
+ public:
+  using Mma = Mma_;
+  using Epilogue = Epilogue_;
+  using EpilogueOutputOp = typename Epilogue::OutputOp;
+  using ThreadblockSwizzle = ThreadblockSwizzle_;
+  static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
+  static bool const kTransposed = false;
+
+  // Optional transpose
+  using MapArguments = kernel::detail::MapArguments;
+
+  // Public-facing type definitions related to operand element type, layout, and complex conjugate
+  // operation. Must interact with the 'kTransposed' notion.
+  static_assert(!kTransposed, "Transpose problem not supported");
+  using ElementA = typename MapArguments::ElementA;
+  using LayoutA = typename MapArguments::LayoutA;
+  using ElementB = typename MapArguments::ElementB;
+  using LayoutB = typename MapArguments::LayoutB;
+  using ElementC = typename Epilogue::OutputTileIterator::Element;
+  using LayoutC = typename MapArguments::LayoutC;
+  using ElementScale = ElementC;
+
+  static ComplexTransform const kTransformA = MapArguments::kTransformA;
+  static ComplexTransform const kTransformB = MapArguments::kTransformB;
+
+  // Type definitions about the mainloop.
+  using Operator = typename Mma::Operator;
+  using OperatorClass = typename Mma::Operator::OperatorClass;
+  using ThreadblockShape = typename Mma::Shape;
+  using WarpShape = typename Mma::Operator::Shape;
+  using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
+  using ArchTag = typename Mma::ArchTag;
+
+  static int const kStages = Mma::kStages;
+  static int const kAlignmentA = MapArguments::kAlignmentA;
+  static int const kAlignmentB = MapArguments::kAlignmentB;
+  static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
+
+  /// Warp count (concept: GemmShape)
+  using WarpCount = typename Mma::WarpCount;
+  static int const kThreadCount = 32 * WarpCount::kCount;
+
+  using ProblemVisitor =
+      GemmMoeProblemVisitor;
 
-    //
-    // Structures
-    //
-
-    /// Argument structure
-    struct Arguments {
-
-        //
-        // Data members
-        //
-
-        int problem_count;
-        int threadblock_count;
-
-        typename EpilogueOutputOp::Params output_op;
-
-        ElementA* ptr_A;
-        ElementB* ptr_B;
-        ElementScale* weight_scales;
-        ElementC* ptr_C;
-        ElementC* ptr_D;
-
-        int64_t* total_rows_before_expert;
-        int64_t gemm_n;
-        int64_t gemm_k;
-
-        // Only used by device-level operator
-        GemmCoord* host_problem_sizes;
-
-        //
-        // Methods
-        //
-
-        /// Default ctor
-        CUTLASS_HOST_DEVICE
-        Arguments():
-            problem_count(0),
-            threadblock_count(0),
-            ptr_A(nullptr),
-            ptr_B(nullptr),
-            weight_scales(nullptr),
-            ptr_C(nullptr),
-            ptr_D(nullptr),
-            total_rows_before_expert(nullptr),
-            gemm_n(0),
-            gemm_k(0),
-            host_problem_sizes(nullptr)
-        {
-        }
-
-        /// Ctor
-        CUTLASS_HOST_DEVICE
-        Arguments(int problem_count,
-                  int threadblock_count,
-                  typename EpilogueOutputOp::Params output_op,
-                  const ElementA* ptr_A,
-                  const ElementB* ptr_B,
-                  const ElementScale* weight_scales,
-                  const ElementC* ptr_C,
-                  ElementC* ptr_D,
-                  int64_t* total_rows_before_expert,
-                  int64_t gemm_n,
-                  int64_t gemm_k,
-                  GemmCoord* host_problem_sizes = nullptr):
-            problem_count(problem_count),
-            threadblock_count(threadblock_count),
-            output_op(output_op),
-            ptr_A(const_cast<ElementA*>(ptr_A)),
-            ptr_B(const_cast<ElementB*>(ptr_B)),
-            weight_scales(const_cast<ElementScale*>(weight_scales)),
-            ptr_C(const_cast<ElementC*>(ptr_C)),
-            ptr_D(ptr_D),
-            total_rows_before_expert(total_rows_before_expert),
-            gemm_n(gemm_n),
-            gemm_k(gemm_k),
-            host_problem_sizes(nullptr)
-        {
-            if (platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value) {
-                assert(weight_scales);
-            }
-        }
-    };
+  //
+  // Structures
+  //
+
+  /// Argument structure
+  struct Arguments {
+    //
+    // Data members
+    //
+
+    int problem_count;
+    int threadblock_count;
+
+    typename EpilogueOutputOp::Params output_op;
+
+    ElementA* ptr_A;
+    ElementB* ptr_B;
+    ElementScale* weight_scales;
+    ElementC* ptr_C;
+    ElementC* ptr_D;
+
+    int64_t* total_rows_before_expert;
+    int64_t gemm_n;
+    int64_t gemm_k;
+
+    // Only used by device-level operator
+    GemmCoord* host_problem_sizes;
+
+    //
+    // Methods
+    //
+
+    /// Default ctor
+    CUTLASS_HOST_DEVICE
+    Arguments() : problem_count(0),
+                  threadblock_count(0),
+                  ptr_A(nullptr),
+                  ptr_B(nullptr),
+                  weight_scales(nullptr),
+                  ptr_C(nullptr),
+                  ptr_D(nullptr),
+                  total_rows_before_expert(nullptr),
+                  gemm_n(0),
+                  gemm_k(0),
+                  host_problem_sizes(nullptr) {
+    }
+
+    /// Ctor
+    CUTLASS_HOST_DEVICE
+    Arguments(int problem_count,
+              int threadblock_count,
+              typename EpilogueOutputOp::Params output_op,
+              const ElementA* ptr_A,
+              const ElementB* ptr_B,
+              const ElementScale* weight_scales,
+              const ElementC* ptr_C,
+              ElementC* ptr_D,
+              int64_t* total_rows_before_expert,
+              int64_t gemm_n,
+              int64_t gemm_k,
+              GemmCoord* host_problem_sizes = nullptr) : problem_count(problem_count),
+                                                         threadblock_count(threadblock_count),
+                                                         output_op(output_op),
+                                                         ptr_A(const_cast<ElementA*>(ptr_A)),
+                                                         ptr_B(const_cast<ElementB*>(ptr_B)),
+                                                         weight_scales(const_cast<ElementScale*>(weight_scales)),
+                                                         ptr_C(const_cast<ElementC*>(ptr_C)),
+                                                         ptr_D(ptr_D),
+                                                         total_rows_before_expert(total_rows_before_expert),
+                                                         gemm_n(gemm_n),
+                                                         gemm_k(gemm_k),
+                                                         host_problem_sizes(nullptr) {
+      if (platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value) {
+        assert(weight_scales);
+      }
+    }
+  };
 
-    //
-    // Structure for precomputing values in host memory and passing to kernels
-    //
-
-    /// Parameters structure
-    struct Params {
-
-        typename ProblemVisitor::Params problem_visitor;
-        int threadblock_count;
-
-        typename EpilogueOutputOp::Params output_op;
-
-        ElementA* ptr_A;
-        ElementB* ptr_B;
-        ElementScale* weight_scales;
-        ElementC* ptr_C;
-        ElementC* ptr_D;
-
-        //
-        // Methods
-        //
-
-        CUTLASS_HOST_DEVICE
-        Params(): ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr) {}
-
-        CUTLASS_HOST_DEVICE
-        Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0):
-            problem_visitor(
-                args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count),
-            threadblock_count(args.threadblock_count),
-            output_op(args.output_op),
-            ptr_A(args.ptr_A),
-            ptr_B(args.ptr_B),
-            weight_scales(args.weight_scales),
-            ptr_C(args.ptr_C),
-            ptr_D(args.ptr_D)
-        {
-        }
-
-        CUTLASS_HOST_DEVICE
-        void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
-        {
-
-            problem_visitor = typename ProblemVisitor::Params(
-                args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count);
-            threadblock_count = args.threadblock_count;
-            output_op = args.output_op;
-            ptr_A = args.ptr_A;
-            ptr_B = args.ptr_B;
-            weight_scales = args.weight_scales;
-            ptr_C = args.ptr_C;
-            ptr_D = args.ptr_D;
-        }
-    };
-
-    /// Shared memory storage structure
-    union SharedStorage {
-        typename ProblemVisitor::SharedStorage problem_visitor;
-        typename Mma::SharedStorage main_loop;
-        typename Epilogue::SharedStorage epilogue;
-    };
-
-public:
-    //
-    // Methods
-    //
-
-    CUTLASS_DEVICE
-    MoeFCGemm() {}
-
-    /// Determines whether kernel satisfies alignment
-    static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
-    {
-        return Status::kSuccess;
-    }
 
-    static Status can_implement(Arguments const& args)
-    {
-        if (platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value) {
-            if (args.weight_scales == nullptr) {
-                CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t");
-                return Status::kInvalid;
-            }
-        }
-        else if (args.weight_scales != nullptr) {
-            CUTLASS_TRACE_HOST(
-                "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t");
-            return Status::kInvalid;
-        }
-        return Status::kSuccess;
-    }
-
-    static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
-    {
-
-        return 0;
-    }
-
-    // The dummy template parameter is not used and exists so that we can compile this code using
-    // a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in
-    // a namespace
-    template<bool B, typename dummy = void>
-    struct KernelRunner {
-        CUTLASS_DEVICE
-        static void run_kernel(Params const& params, SharedStorage& shared_storage)
-        {
-            CUTLASS_NOT_IMPLEMENTED();
-        }
-    };
-
-    template<typename dummy>
-    struct KernelRunner<true, dummy> {
-        CUTLASS_DEVICE
-        static void run_kernel(Params const& params, SharedStorage& shared_storage)
-        {
-            //
-            // These types shadow the type-level definitions and support the ability to implement
-            // a 'transposed' GEMM that computes the transposed problems.
-            //
-            using ElementA = typename Mma::IteratorA::Element;
-            using LayoutA = typename Mma::IteratorA::Layout;
-            using ElementB = typename Mma::IteratorB::Element;
-            using LayoutB = typename Mma::IteratorB::Layout;
-            using ElementC = typename Epilogue::OutputTileIterator::Element;
-            using LayoutC = typename Epilogue::OutputTileIterator::Layout;
-            static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
-            static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
-                              || platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
-                          "B must be row major/col major OR col major interleaved.");
-
-            //
-            // Problem visitor.
-            //
-            ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
-
-            const int64_t gemm_k = params.problem_visitor.gemm_k;
-            const int64_t gemm_n = params.problem_visitor.gemm_n;
-            int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
-
-            // Outer 'persistent' loop to iterate over tiles
-            while (problem_visitor.next_tile()) {
-
-                GemmCoord problem_size = problem_visitor.problem_size();
-                int32_t problem_idx = problem_visitor.problem_index();
-                int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
-
-                GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
-
-                cutlass::gemm::GemmCoord threadblock_offset(
-                    int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
-
-                // Load element pointers. Exchange pointers and strides if working on the transpose
-                const int64_t rows_to_jump =
-                    problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
-                ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
-                typename LayoutA::LongIndex ldm_A = gemm_k;
-
-                char* byte_ptr_B = ((char*)params.ptr_B) + problem_idx * bytes_per_expert_matrix;
-                ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
-                typename LayoutB::LongIndex ldm_B =
-                    platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
-
-                // Compute initial location in logical coordinates
-                cutlass::MatrixCoord tb_offset_A{
-                    threadblock_offset.m(),
-                    0,
-                };
-
-                cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
-
-                cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
-
-                // Compute position within threadblock
-                int thread_idx = threadIdx.x;
-
-                // Construct iterators to A and B operands
-                typename Mma::IteratorA iterator_A(
-                    LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
-
-                typename Mma::IteratorB iterator_B(LayoutB(ldm_B),
-                                                   ptr_B,
-                                                   {problem_size.k() * kInterleave, problem_size.n() / kInterleave},
-                                                   thread_idx,
-                                                   tb_offset_B);
-
-                typename Mma::FragmentC accumulators;
-
-                accumulators.clear();
-
-                // Broadcast the warp_id computed by lane 0 to ensure dependent code
-                // is compiled as warp-uniform.
+  //
+  // Structure for precomputing values in host memory and passing to kernels
+  //
+
+  /// Parameters structure
+  struct Params {
+    typename ProblemVisitor::Params problem_visitor;
+    int threadblock_count;
+
+    typename EpilogueOutputOp::Params output_op;
+
+    ElementA* ptr_A;
+    ElementB* ptr_B;
+    ElementScale* weight_scales;
+    ElementC* ptr_C;
+    ElementC* ptr_D;
+
+    //
+    // Methods
+    //
+
+    CUTLASS_HOST_DEVICE
+    Params() : ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr) {}
+
+    CUTLASS_HOST_DEVICE
+    Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) : problem_visitor(
+                                                                                       args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count),
+                                                                                   threadblock_count(args.threadblock_count),
+                                                                                   output_op(args.output_op),
+                                                                                   ptr_A(args.ptr_A),
+                                                                                   ptr_B(args.ptr_B),
+                                                                                   weight_scales(args.weight_scales),
+                                                                                   ptr_C(args.ptr_C),
+                                                                                   ptr_D(args.ptr_D) {
+    }
+
+    CUTLASS_HOST_DEVICE
+    void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) {
+      problem_visitor = typename ProblemVisitor::Params(
+          args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count);
+      threadblock_count = args.threadblock_count;
+      output_op = args.output_op;
+      ptr_A = args.ptr_A;
+      ptr_B = args.ptr_B;
+      weight_scales = args.weight_scales;
+      ptr_C = args.ptr_C;
+      ptr_D = args.ptr_D;
+    }
+  };
+
+  /// Shared memory storage structure
+  union SharedStorage {
+    typename ProblemVisitor::SharedStorage problem_visitor;
+    typename Mma::SharedStorage main_loop;
+    typename Epilogue::SharedStorage epilogue;
+  };
+
+ public:
+  //
+  // Methods
+  //
+
+  CUTLASS_DEVICE
+  MoeFCGemm() {}
+
+  /// Determines whether kernel satisfies alignment
+  static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) {
+    return Status::kSuccess;
+  }
+
+  static Status can_implement(Arguments const& args) {
+    if (platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value) {
+      if (args.weight_scales == nullptr) {
+        CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t");
+        return Status::kInvalid;
+      }
+    } else if (args.weight_scales != nullptr) {
+      CUTLASS_TRACE_HOST(
+          "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t");
+      return Status::kInvalid;
+    }
+    return Status::kSuccess;
+  }
+
+  static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) {
+    return 0;
+  }
+
+  // The dummy template parameter is not used and exists so that we can compile this code using
+  // a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exist in
+  // a namespace
+  template <bool B, typename dummy = void>
+  struct KernelRunner {
+    CUTLASS_DEVICE
+    static void run_kernel(Params const& params, SharedStorage& shared_storage) {
+      CUTLASS_NOT_IMPLEMENTED();
+    }
+  };
+
+  template <typename dummy>
+  struct KernelRunner<true, dummy> {
+    CUTLASS_DEVICE
+    static void run_kernel(Params const& params, SharedStorage& shared_storage) {
+      //
+      // These types shadow the type-level definitions and support the ability to implement
+      // a 'transposed' GEMM that computes the transposed problems.
+      //
+      using ElementA = typename Mma::IteratorA::Element;
+      using LayoutA = typename Mma::IteratorA::Layout;
+      using ElementB = typename Mma::IteratorB::Element;
+      using LayoutB = typename Mma::IteratorB::Layout;
+      using ElementC = typename Epilogue::OutputTileIterator::Element;
+      using LayoutC = typename Epilogue::OutputTileIterator::Layout;
+      static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
+      static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1 || platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
+                    "B must be row major/col major OR col major interleaved.");
+
+      //
+      // Problem visitor.
+      //
+      ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
+
+      const int64_t gemm_k = params.problem_visitor.gemm_k;
+      const int64_t gemm_n = params.problem_visitor.gemm_n;
+      int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
+
+      // Outer 'persistent' loop to iterate over tiles
+      while (problem_visitor.next_tile()) {
+        GemmCoord problem_size = problem_visitor.problem_size();
+        int32_t problem_idx = problem_visitor.problem_index();
+        int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
+
+        GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
+
+        cutlass::gemm::GemmCoord threadblock_offset(
+            int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
+
+        // Load element pointers. Exchange pointers and strides if working on the transpose
+        const int64_t rows_to_jump =
+            problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
+        ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
+        typename LayoutA::LongIndex ldm_A = gemm_k;
+
+        char* byte_ptr_B = ((char*)params.ptr_B) + problem_idx * bytes_per_expert_matrix;
+        ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
+        typename LayoutB::LongIndex ldm_B =
+            platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
+
+        // Compute initial location in logical coordinates
+        cutlass::MatrixCoord tb_offset_A{
+            threadblock_offset.m(),
+            0,
+        };
+
+        cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
+
+        cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
+
+        // Compute position within threadblock
+        int thread_idx = threadIdx.x;
+
+        // Construct iterators to A and B operands
+        typename Mma::IteratorA iterator_A(
+            LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
+
+        typename Mma::IteratorB iterator_B(LayoutB(ldm_B),
+                                           ptr_B,
+                                           {problem_size.k() * kInterleave, problem_size.n() / kInterleave},
+                                           thread_idx,
+                                           tb_offset_B);
+
+        typename Mma::FragmentC accumulators;
+
+        accumulators.clear();
+
+        // Broadcast the warp_id computed by lane 0 to ensure dependent code
+        // is compiled as warp-uniform.
+        int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
+
+        int lane_idx = threadIdx.x % 32;
 
+        //
+        // Matrix multiply phase
+        //
+
+        // Construct thread-scoped matrix multiply
+        Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
+
+        // Compute threadblock-scoped matrix multiply-add
+        int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
+
+        // Wait for all threads to finish their epilogue phases from the previous tile.
+        __syncthreads();
+
+        // Compute threadblock-scoped matrix multiply-add
+        ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();
+        run_mma(mma,
+                gemm_k_iterations,
+                accumulators,
+                iterator_A,
+                iterator_B,
+                accumulators,
+                weight_scale_ptr,
+                {1, problem_size.n()},
+                thread_idx,
+                tb_offset_scale);
+
+        //
+        // Epilogue
+        //
+
+        EpilogueOutputOp output_op(params.output_op);
+
+        ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n;
+        ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
+
+        LayoutC layout_C(0);
+        LayoutC layout_D(gemm_n);
+
+        typename Epilogue::OutputTileIterator::Params params_C(layout_C);
+        typename Epilogue::OutputTileIterator::Params params_D(layout_D);
+
+        // Tile iterator loading from source tensor.
+        typename Epilogue::OutputTileIterator iterator_C(
+            params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());
+
+        // Tile iterator writing to destination tensor.
+        typename Epilogue::OutputTileIterator iterator_D(
+            params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());
+
+        Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
+
+        // Execute the epilogue operator to update the destination tensor.
+        epilogue(output_op, iterator_D, accumulators, iterator_C);
+
+        // Next tile
+        problem_visitor.advance(gridDim.x);
+      }
+    }
+  };
 
-    /*
-        To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
-        to the ArchTag of the cutlass kernel operator.
-    */
-    /// Executes one GEMM
-    CUTLASS_DEVICE
-    void operator()(Params const& params, SharedStorage& shared_storage)
-    {
+  /*
+    To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
+    to the ArchTag of the cutlass kernel operator.
+  */
+  /// Executes one GEMM
+  CUTLASS_DEVICE
+  void operator()(Params const& params, SharedStorage& shared_storage) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
-        static constexpr bool compile_needed = platform::is_same<cutlass::arch::Sm70, KernelArch>::value;
-        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
+    static constexpr bool compile_needed = platform::is_same<cutlass::arch::Sm70, KernelArch>::value;
+    KernelRunner<compile_needed>::run_kernel(params, shared_storage);
 #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
-        static constexpr bool compile_needed = platform::is_same<cutlass::arch::Sm75, KernelArch>::value;
-        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
+    static constexpr bool compile_needed = platform::is_same<cutlass::arch::Sm75, KernelArch>::value;
+    KernelRunner<compile_needed>::run_kernel(params, shared_storage);
 #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
-        static constexpr bool compile_needed = platform::is_same<cutlass::arch::Sm80, KernelArch>::value;
-        KernelRunner<compile_needed>::run_kernel(params, shared_storage);
+    static constexpr bool compile_needed = platform::is_same<cutlass::arch::Sm80, KernelArch>::value;
+    KernelRunner<compile_needed>::run_kernel(params, shared_storage);
 #else
-        CUTLASS_NOT_IMPLEMENTED();
+    CUTLASS_NOT_IMPLEMENTED();
 #endif
-    }
+  }
 };
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
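Worth pulling out of the kernel above: every expert's operands are located purely by arithmetic on problem_idx and the row prefix sum, so no per-expert pointer array is needed. A host-side sketch of the same indexing, with element types simplified to float/int8 for illustration:

#include <cstdint>

// Mirrors the kernel's pointer math: A and D are offset by rows already
// consumed (rows_to_jump), B by whole packed expert matrices.
struct ExpertViews {
  const float* ptr_A;
  const int8_t* ptr_B;
  float* ptr_D;
};

ExpertViews locate_expert(const float* A, const int8_t* B, float* D,
                          const int64_t* last_row_for_problem, int problem_idx,
                          int64_t gemm_n, int64_t gemm_k, int64_t bytes_per_expert_matrix) {
  const int64_t rows_to_jump = problem_idx == 0 ? 0 : last_row_for_problem[problem_idx - 1];
  const char* byte_ptr_B = reinterpret_cast<const char*>(B) + problem_idx * bytes_per_expert_matrix;
  return ExpertViews{A + rows_to_jump * gemm_k,
                     reinterpret_cast<const int8_t*>(byte_ptr_B),
                     D + rows_to_jump * gemm_n};
}
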
+ typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn()); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + }; + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) - static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) - static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); + static constexpr bool compile_needed = platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); #else - CUTLASS_NOT_IMPLEMENTED(); + CUTLASS_NOT_IMPLEMENTED(); #endif - } + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h index b913c51fc2fd2..eacc32f0e0668 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -22,78 +22,78 @@ namespace fastertransformer { enum class ActivationType { - Gelu, - Relu, - Silu, - GeGLU, - ReGLU, - SiGLU, - Identity, - InvalidType + Gelu, + Relu, + Silu, + GeGLU, + ReGLU, + SiGLU, + Identity, + InvalidType }; -template +template class MoeGemmRunner { -public: - MoeGemmRunner(); + public: + MoeGemmRunner(); - void moe_gemm_bias_act(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - ActivationType activation_type, - cudaStream_t stream); + void moe_gemm_bias_act(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + ActivationType activation_type, + cudaStream_t stream); - void moe_gemm(const T* A, - const WeightType* B, - const T* weight_scales, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - cudaStream_t stream); + void moe_gemm(const T* A, + const WeightType* B, + const T* weight_scales, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + cudaStream_t stream); -private: - template - void dispatch_to_arch(const T* A, - const WeightType* 
B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - cudaStream_t stream, - int* occupancy = nullptr); + private: + template + void dispatch_to_arch(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + cudaStream_t stream, + int* occupancy = nullptr); - template - void run_gemm(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - cudaStream_t stream); + template + void run_gemm(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + cudaStream_t stream); -private: - int sm_; - int multi_processor_count_; + private: + int sm_; + int multi_processor_count_; }; } // namespace fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index c68ad97bb968c..be3acdf7f53e0 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -46,224 +46,214 @@ namespace fastertransformer { -inline int getSMVersion() -{ - int device{-1}; - cudaGetDevice(&device); - int sm_major = 0; - int sm_minor = 0; - cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device); - cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device); - return sm_major * 10 + sm_minor; +inline int getSMVersion() { + int device{-1}; + cudaGetDevice(&device); + int sm_major = 0; + int sm_minor = 0; + cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device); + cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device); + return sm_major * 10 + sm_minor; } // ============================= Variable batched Gemm things =========================== -template -void generic_moe_gemm_kernelLauncher(const T* A, +template +void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, - const int multi_processor_count, - cudaStream_t stream, - int* kernel_occupancy = nullptr) -{ - if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) { - throw std::runtime_error("[FT Error][MoeGemm] Grouped gemm does not support split-k"); - } - - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for half, float"); - - static_assert(cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, - ""); - - // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. 
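// Illustrative sketch of the conditional-type idiom used just below, written
// against the standard library (std::conditional is the analogue of
// cutlass::platform::conditional; the alias name here is hypothetical):
#include <type_traits>
template <typename T, typename From, typename To>
using SwapIfSame = typename std::conditional<std::is_same<T, From>::value, To, T>::type;
static_assert(std::is_same<SwapIfSame<float, float, double>, double>::value, "matched type is swapped");
static_assert(std::is_same<SwapIfSame<int, float, double>, int>::value, "other types pass through unchanged");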
- using ElementType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; - using ElementType = ElementType_; - - - using CutlassWeightType_ = typename cutlass::platform:: - conditional::value, cutlass::half_t, WeightType>::type; - using CutlassWeightType = CutlassWeightType_; - - // We need separate config for each architecture since we will target different tensorcore instructions. For float, - // we do not target TCs. - using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; - using ElementAccumulator = typename MixedGemmArchTraits::AccType; - - using EpilogueOp = - typename Epilogue::Op; - - // Finally, set up the kernel. - using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped< - ElementType, - cutlass::layout::RowMajor, - cutlass::ComplexTransform::kNone, - MixedGemmArchTraits::ElementsPerAccessA, - CutlassWeightType, - typename MixedGemmArchTraits::LayoutB, - cutlass::ComplexTransform::kNone, - MixedGemmArchTraits::ElementsPerAccessB, - ElementType, - cutlass::layout::RowMajor, - ElementAccumulator, - typename MixedGemmArchTraits::OperatorClass, - arch, - ThreadblockShape, - WarpShape, - typename MixedGemmArchTraits::InstructionShape, - EpilogueOp, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, - Stages, - cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, - typename MixedGemmArchTraits::Operator>::GemmKernel; - - using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; - - using GemmGrouped = cutlass::gemm::device::GemmGrouped; - - if (kernel_occupancy != nullptr) { - *kernel_occupancy = compute_occupancy_for_kernel(); - return; - } - int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); - if (occupancy == 0) { - throw std::runtime_error( - "[FT Error][MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM kernel"); - } - const int threadblock_count = multi_processor_count * occupancy; - - typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); - - typename GemmGrouped::Arguments args(num_experts, - threadblock_count, - epilogue_op, - reinterpret_cast(A), - reinterpret_cast(B), - reinterpret_cast(weight_scales), - reinterpret_cast(biases), - reinterpret_cast(C), - total_rows_before_expert, - gemm_n, - gemm_k); - - GemmGrouped gemm; - - auto can_implement = gemm.can_implement(args); - if (can_implement != cutlass::Status::kSuccess) { - std::string err_msg = - "MoEFC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); - throw std::runtime_error("[FT Error][MoE Runner] " + err_msg); - } - - auto init_status = gemm.initialize(args); - if (init_status != cutlass::Status::kSuccess) { - std::string err_msg = "Failed to initialize cutlass variable batched gemm. Error: " - + std::string(cutlassGetStatusString(init_status)); - throw std::runtime_error("[FT Error][MoE Runner] " + err_msg); - } - - auto run_status = gemm.run(stream); - if (run_status != cutlass::Status::kSuccess) { - std::string err_msg = - "Failed to run cutlass variable batched gemm. 
Error: " + std::string(cutlassGetStatusString(run_status)); - throw std::runtime_error("[FT Error][MoE Runner] " + err_msg); - } + const int multi_processor_count, + cudaStream_t stream, + int* kernel_occupancy = nullptr) { + if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) { + throw std::runtime_error("[FT Error][MoeGemm] Grouped gemm does not support split-k"); + } + + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float"); + + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType_ = + typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; + using ElementType = ElementType_; + + using CutlassWeightType_ = typename cutlass::platform:: + conditional::value, cutlass::half_t, WeightType>::type; + using CutlassWeightType = CutlassWeightType_; + + // We need separate config for each architecture since we will target different tensorcore instructions. For float, + // we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = + typename Epilogue::Op; + + // Finally, set up the kernel. + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped< + ElementType, + cutlass::layout::RowMajor, + cutlass::ComplexTransform::kNone, + MixedGemmArchTraits::ElementsPerAccessA, + CutlassWeightType, + typename MixedGemmArchTraits::LayoutB, + cutlass::ComplexTransform::kNone, + MixedGemmArchTraits::ElementsPerAccessB, + ElementType, + cutlass::layout::RowMajor, + ElementAccumulator, + typename MixedGemmArchTraits::OperatorClass, + arch, + ThreadblockShape, + WarpShape, + typename MixedGemmArchTraits::InstructionShape, + EpilogueOp, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + Stages, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + typename MixedGemmArchTraits::Operator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + if (kernel_occupancy != nullptr) { + *kernel_occupancy = compute_occupancy_for_kernel(); + return; + } + int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + if (occupancy == 0) { + throw std::runtime_error( + "[FT Error][MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM kernel"); + } + const int threadblock_count = multi_processor_count * occupancy; + + typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); + + typename GemmGrouped::Arguments args(num_experts, + threadblock_count, + epilogue_op, + reinterpret_cast(A), + reinterpret_cast(B), + reinterpret_cast(weight_scales), + reinterpret_cast(biases), + reinterpret_cast(C), + total_rows_before_expert, + gemm_n, + gemm_k); + + GemmGrouped gemm; + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = + "MoEFC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + throw std::runtime_error("[FT Error][MoE Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to initialize cutlass variable batched gemm. 
Error: " + std::string(cutlassGetStatusString(init_status)); + throw std::runtime_error("[FT Error][MoE Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = + "Failed to run cutlass variable batched gemm. Error: " + std::string(cutlassGetStatusString(run_status)); + throw std::runtime_error("[FT Error][MoE Runner] " + err_msg); + } } -template +template struct dispatch_stages { - static void dispatch(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) - { - std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " - + std::to_string(arch::kMinComputeCapability) + " with stages set to " - + std::to_string(Stages); - throw std::runtime_error("[FT Error][dispatch_stages::dispatch] " + err_msg); - } + static void dispatch(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { + std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + throw std::runtime_error("[FT Error][dispatch_stages::dispatch] " + err_msg); + } }; -template +template struct dispatch_stages { - static void dispatch(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) - { - generic_moe_gemm_kernelLauncher( - A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - gemm_n, - gemm_k, - num_experts, - gemm_config, - multi_processor_count, - stream, - occupancy); - } + static void dispatch(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { + generic_moe_gemm_kernelLauncher( + A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + } }; -template +template struct dispatch_stages 2)>::type> { - static void dispatch(const T* A, - const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) - { - generic_moe_gemm_kernelLauncher(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - gemm_n, - gemm_k, - num_experts, - gemm_config, - multi_processor_count, - stream, - occupancy); - } + static void dispatch(const T* A, + const WeightType* B, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + CutlassGemmConfig gemm_config, + int 
multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { + generic_moe_gemm_kernelLauncher(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + } }; -template -void dispatch_gemm_config(const T* A, +template +void dispatch_gemm_config(const T* A, const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) -{ - switch (gemm_config.stages) { - case 2: - using DispatcherStages2 = dispatch_stages; - DispatcherStages2::dispatch(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - gemm_n, - gemm_k, - num_experts, - gemm_config, - multi_processor_count, - stream, - occupancy); - break; - case 3: - using DispatcherStages3 = dispatch_stages; - DispatcherStages3::dispatch(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - gemm_n, - gemm_k, - num_experts, - gemm_config, - multi_processor_count, - stream, - occupancy); - break; - case 4: - using DispatcherStages4 = dispatch_stages; - DispatcherStages4::dispatch(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - gemm_n, - gemm_k, - num_experts, - gemm_config, - multi_processor_count, - stream, - occupancy); - break; - default: - std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); - throw std::runtime_error("[FT Error][MoE][dispatch_gemm_config] " + err_msg); - break; - } + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemm_config.stages) { + case 2: + using DispatcherStages2 = dispatch_stages; + DispatcherStages2::dispatch(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case 3: + using DispatcherStages3 = dispatch_stages; + DispatcherStages3::dispatch(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case 4: + using DispatcherStages4 = dispatch_stages; + DispatcherStages4::dispatch(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + default: + std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); + throw std::runtime_error("[FT Error][MoE][dispatch_gemm_config] " + err_msg); + break; + } } // This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. // This overload is only enabled when T == WeightType. 
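// Illustrative sketch of the enable_if-based overload selection used by the
// dispatch_moe_gemm_to_cutlass overloads below; the function name here is
// hypothetical:
#include <type_traits>
template <typename T, typename W,
          typename std::enable_if<std::is_same<T, W>::value>::type* = nullptr>
const char* selected_path() { return "same-type tensorop GEMM"; }
template <typename T, typename W,
          typename std::enable_if<!std::is_same<T, W>::value>::type* = nullptr>
const char* selected_path() { return "mixed-type weight-only GEMM"; }
// selected_path<float, float>() picks the first overload and
// selected_path<float, int>() the second: exactly one substitution succeeds,
// so the call is never ambiguous.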
-template::value && std::is_same::value>::type* = nullptr> -void dispatch_moe_gemm_to_cutlass(const T* A, +template ::value && std::is_same::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, - int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) { - case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<32, 32, 64>>(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - gemm_n, - gemm_k, - num_experts, - gemm_config, - multi_processor_count, - stream, - occupancy); - break; - case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<32, 64, 64>>(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - gemm_n, - gemm_k, - num_experts, - gemm_config, - multi_processor_count, - stream, - occupancy); - break; - case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 32, 64>>(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - gemm_n, - gemm_k, - num_experts, - gemm_config, - multi_processor_count, - stream, - occupancy); - break; - case CutlassTileConfig::Undefined: - throw std::runtime_error("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); - break; - case CutlassTileConfig::ChooseWithHeuristic: - throw std::runtime_error( - "[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); - break; - default: - throw std::runtime_error( - "[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for same type MoE tensorop GEMM."); - break; - } + int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 64, 64>>(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case CutlassTileConfig::Undefined: + throw std::runtime_error("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error( + "[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + break; + default: + throw std::runtime_error( + "[FT 
Error][dispatch_moe_gemm_to_cutlass] Config is invalid for same type MoE tensorop GEMM."); + break; + } } // Tensorop GEMM overload // Overload for quantize MoE GEMMs. We disable some warp configs here since they will not be used and we can improve // compile time -template< +template < typename T, typename WeightType, typename arch, typename EpilogueTag, typename std::enable_if::value && !std::is_same::value>::type* = nullptr> -void dispatch_moe_gemm_to_cutlass(const T* A, +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, - int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) { - case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<32, 32, 64>>(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - gemm_n, - gemm_k, - num_experts, - gemm_config, - multi_processor_count, - stream, - occupancy); - break; - case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 32, 64>>(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - gemm_n, - gemm_k, - num_experts, - gemm_config, - multi_processor_count, - stream, - occupancy); - break; - case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatch_gemm_config, - cutlass::gemm::GemmShape<128, 32, 64>>(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - gemm_n, - gemm_k, - num_experts, - gemm_config, - multi_processor_count, - stream, - occupancy); - break; - case CutlassTileConfig::Undefined: - throw std::runtime_error("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); - break; - case CutlassTileConfig::ChooseWithHeuristic: - throw std::runtime_error( - "[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); - break; - default: - throw std::runtime_error( - "[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for mixed type tensorop GEMM."); - break; - } + int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, 64>>(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case CutlassTileConfig::Undefined: + throw std::runtime_error("[FT 
Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error( + "[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + break; + default: + throw std::runtime_error( + "[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for mixed type tensorop GEMM."); + break; + } } // This overload will handle simt gemms. It is disabled via SFINAE for tensorop. -template::value>::type* = nullptr> -void dispatch_moe_gemm_to_cutlass(const T* A, +template ::value>::type* = nullptr> +void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, - int sm_version, - int multi_processor_count, - cudaStream_t stream, - int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) { - case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: - dispatch_gemm_config, - cutlass::gemm::GemmShape<64, 64, 8>>(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - gemm_n, - gemm_k, - num_experts, - gemm_config, - multi_processor_count, - stream, - occupancy); - break; - case CutlassTileConfig::Undefined: - throw std::runtime_error("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config undefined."); - break; - case CutlassTileConfig::ChooseWithHeuristic: - throw std::runtime_error( - "[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config should have already been set by heuristic."); - break; - default: - throw std::runtime_error( - "[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] Unsupported config for float MoE gemm."); - break; - } + int sm_version, + int multi_processor_count, + cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 64, 8>>(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + gemm_n, + gemm_k, + num_experts, + gemm_config, + multi_processor_count, + stream, + occupancy); + break; + case CutlassTileConfig::Undefined: + throw std::runtime_error("[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error( + "[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] gemm config should have already been set by heuristic."); + break; + default: + throw std::runtime_error( + "[FT Error][dispatch_moe_gemm_to_cutlass][SIMT] Unsupported config for float MoE gemm."); + break; + } } -template -MoeGemmRunner::MoeGemmRunner() -{ - - int device{-1}; - cudaGetDevice(&device); - sm_ = getSMVersion(); - cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device); +template +MoeGemmRunner::MoeGemmRunner() { + int device{-1}; + cudaGetDevice(&device); + sm_ = getSMVersion(); + cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device); } -template -template -void MoeGemmRunner::dispatch_to_arch(const T* A, +template +template +void MoeGemmRunner::dispatch_to_arch(const T* A, const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - 
int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, - cudaStream_t stream, - int* occupancy) -{ - - - if (sm_ >= 70 && sm_ < 75) { - dispatch_moe_gemm_to_cutlass(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - gemm_config, - sm_, - multi_processor_count_, - stream, - occupancy); - } - else if (sm_ >= 75 && sm_ < 80) { - dispatch_moe_gemm_to_cutlass(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - gemm_config, - sm_, - multi_processor_count_, - stream, - occupancy); - } - else if (sm_ >= 80 && sm_ < 90) { - dispatch_moe_gemm_to_cutlass(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - gemm_config, - sm_, - multi_processor_count_, - stream, - occupancy); - } - else { - throw std::runtime_error("[FT Error][MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); - } + cudaStream_t stream, + int* occupancy) { + if (sm_ >= 70 && sm_ < 75) { + dispatch_moe_gemm_to_cutlass(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + sm_, + multi_processor_count_, + stream, + occupancy); + } else if (sm_ >= 75 && sm_ < 80) { + dispatch_moe_gemm_to_cutlass(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + sm_, + multi_processor_count_, + stream, + occupancy); + } else if (sm_ >= 80 && sm_ < 90) { + dispatch_moe_gemm_to_cutlass(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + gemm_config, + sm_, + multi_processor_count_, + stream, + occupancy); + } else { + throw std::runtime_error("[FT Error][MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); + } } -template -template -void MoeGemmRunner::run_gemm(const T* A, +template +template +void MoeGemmRunner::run_gemm(const T* A, const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - cudaStream_t stream) -{ - static constexpr bool is_weight_only = !std::is_same::value; - static constexpr bool only_simt_configs = std::is_same::value; - std::vector candidate_configs = get_candidate_configs(sm_, is_weight_only, only_simt_configs); - std::vector occupancies(candidate_configs.size()); - - for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { - dispatch_to_arch(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - candidate_configs[ii], - stream, - &occupancies[ii]); - } - - static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. - static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k. 
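// Illustrative sketch of the profile-then-select flow in run_gemm: the loop
// above records an occupancy per candidate tile config, then a heuristic
// chooses which config to launch. The real chooser,
// estimate_best_config_from_occupancies, also weighs tile counts and wave
// utilization; this hypothetical reduction shows only the basic shape:
#include <stdexcept>
#include <vector>
inline size_t pick_highest_occupancy(const std::vector<int>& occupancies) {
  size_t best = 0;
  for (size_t i = 1; i < occupancies.size(); ++i) {
    if (occupancies[i] > occupancies[best]) best = i;  // prefer more resident CTAs per SM
  }
  if (occupancies[best] == 0) throw std::runtime_error("no candidate config can run");
  return best;  // index into candidate_configs
}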
- CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(candidate_configs, - occupancies, - total_rows, - gemm_n, - gemm_k, - num_experts, - split_k_limit, - workspace_bytes, - multi_processor_count_, - is_weight_only); - + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + cudaStream_t stream) { + static constexpr bool is_weight_only = !std::is_same::value; + static constexpr bool only_simt_configs = std::is_same::value; + std::vector candidate_configs = get_candidate_configs(sm_, is_weight_only, only_simt_configs); + std::vector occupancies(candidate_configs.size()); + + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { dispatch_to_arch(A, B, weight_scales, @@ -781,66 +729,67 @@ void MoeGemmRunner::run_gemm(const T* A, gemm_n, gemm_k, num_experts, - chosen_config, - stream); + candidate_configs[ii], + stream, + &occupancies[ii]); + } + + static constexpr int workspace_bytes = 0; // No workspace for MoE GEMMs. + static constexpr int split_k_limit = 1; // MoE GEMM does not support split-k. + CutlassGemmConfig chosen_config = estimate_best_config_from_occupancies(candidate_configs, + occupancies, + total_rows, + gemm_n, + gemm_k, + num_experts, + split_k_limit, + workspace_bytes, + multi_processor_count_, + is_weight_only); + + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + chosen_config, + stream); } -template -void MoeGemmRunner::moe_gemm_bias_act(const T* A, +template +void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightType* B, - const T* weight_scales, - const T* biases, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - ActivationType activation_type, - cudaStream_t stream) -{ - switch (activation_type) { - case ActivationType::Relu: - run_gemm(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - stream); - break; - case ActivationType::Gelu: - run_gemm(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - stream); - break; - case ActivationType::Silu: - run_gemm(A, - B, - weight_scales, - biases, - C, - total_rows_before_expert, - total_rows, - gemm_n, - gemm_k, - num_experts, - stream); - break; - case ActivationType::Identity: - run_gemm(A, + const T* weight_scales, + const T* biases, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + ActivationType activation_type, + cudaStream_t stream) { + switch (activation_type) { + case ActivationType::Relu: + run_gemm(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + stream); + break; + case ActivationType::Gelu: + run_gemm(A, B, weight_scales, biases, @@ -851,30 +800,55 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, gemm_k, num_experts, stream); - break; - case ActivationType::InvalidType: - std::runtime_error("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); - break; - default: { - std::runtime_error("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); - } + break; + case ActivationType::Silu: + run_gemm(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + stream); + 
break; + case ActivationType::Identity: + run_gemm(A, + B, + weight_scales, + biases, + C, + total_rows_before_expert, + total_rows, + gemm_n, + gemm_k, + num_experts, + stream); + break; + case ActivationType::InvalidType: + throw std::runtime_error("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); + break; + default: { + throw std::runtime_error("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); } + } } -template -void MoeGemmRunner::moe_gemm(const T* A, +template +void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, - const T* weight_scales, - T* C, - int64_t* total_rows_before_expert, - int64_t total_rows, - int64_t gemm_n, - int64_t gemm_k, - int num_experts, - cudaStream_t stream) -{ - run_gemm( - A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); + const T* weight_scales, + T* C, + int64_t* total_rows_before_expert, + int64_t total_rows, + int64_t gemm_n, + int64_t gemm_k, + int num_experts, + cudaStream_t stream) { + run_gemm( + A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); } } // namespace fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index 44d5d3ea900e4..65060975520a5 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -21,10 +21,9 @@ namespace fastertransformer { -static inline size_t pad_to_multiple_of_16(const size_t& input) -{ - static constexpr int ALIGNMENT = 16; - return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +static inline size_t pad_to_multiple_of_16(const size_t& input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); } /* @@ -43,189 +42,188 @@ static inline size_t pad_to_multiple_of_16(const size_t& input) num_experts - The number of expert layers present k - k value in topk */ -template -void topk_gating_softmax_kernelLauncher(const T* input, - const bool* finished, - T* output, - T* softmax_temp_out, - int* indices, - int* source_row, - const int num_rows, - const int num_experts, - const int k, +template +void topk_gating_softmax_kernelLauncher(const T* input, + const bool* finished, + T* output, + T* softmax_temp_out, + int* indices, + int* source_row, + const int num_rows, + const int num_experts, + const int k, cudaStream_t stream); class CubKeyValueSorter { -public: - CubKeyValueSorter(); + public: + CubKeyValueSorter(); - CubKeyValueSorter(const int num_experts); + CubKeyValueSorter(const int num_experts); - void update_num_experts(const int num_experts); + void update_num_experts(const int num_experts); - size_t getWorkspaceSize(const size_t num_key_value_pairs); + size_t getWorkspaceSize(const size_t num_key_value_pairs); - void run(void* workspace, - const size_t workspace_size, - const int* keys_in, - int* keys_out, - const int* values_in, - int* values_out, - const size_t num_key_value_pairs, - cudaStream_t stream); + void run(void* workspace, + const size_t workspace_size, + const int* keys_in, + int* keys_out, + const int* values_in, + int* values_out, + const size_t num_key_value_pairs, + cudaStream_t stream); -private: - size_t num_key_value_pairs_; - int num_experts_; - int num_bits_; + private: + size_t num_key_value_pairs_; + int num_experts_; + int num_bits_; }; -template -void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, - T* permuted_output, - const int*
expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, - const int num_rows, - const int active_rows, - const int cols, - const int k, +template +void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, + T* permuted_output, + const int* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + const int num_rows, + const int active_rows, + const int cols, + const int k, cudaStream_t stream); -template -void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, - T* reduced_unpermuted_output, - const T* bias, - const T* scales, - const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, - const int num_rows, - const int cols, - const int k, +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, + T* reduced_unpermuted_output, + const T* bias, + const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, + const int num_rows, + const int cols, + const int k, cudaStream_t stream); -template -void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, - T* reduced_unpermuted_output, - const T* skip, - const T* bias, - const T* scales, - const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, - const int num_rows, - const int cols, - const int k, +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, + T* reduced_unpermuted_output, + const T* skip, + const T* bias, + const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, + const int num_rows, + const int cols, + const int k, cudaStream_t stream); -template -void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, - T* reduced_unpermuted_output, - const T* skip_1, - const T* skip_2, - const T* bias, - const T* scales, - const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, - const int num_rows, - const int cols, - const int k, +template +void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, + T* reduced_unpermuted_output, + const T* skip_1, + const T* skip_2, + const T* bias, + const T* scales, + const int* expanded_source_row_to_expanded_dest_row, + const int* expert_for_source_row, + const int num_rows, + const int cols, + const int k, cudaStream_t stream); // Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc . // Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive. // Avoid making several duplicates of this class. 
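// Illustrative host-side usage sketch for the runner declared below
// (hypothetical helper; assumes the declarations in this header are in scope,
// and elides the long run_moe_fc argument list, which follows the declaration
// below):
#include <cuda_runtime.h>
#include <cuda_fp16.h>
inline char* allocate_moe_workspace(int num_rows, int hidden_size, int inter_size,
                                    int num_experts, int k) {
  fastertransformer::CutlassMoeFCRunner<__half, __half> runner;
  size_t bytes = runner.getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, k);
  char* workspace = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&workspace), bytes);  // reuse across run_moe_fc calls
  return workspace;
}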
-template +template class CutlassMoeFCRunner { -public: - CutlassMoeFCRunner(); - - size_t getWorkspaceSize( - const int num_rows, const int hidden_size, const int inter_size, const int num_experts, const int k); - - void run_moe_fc(const T* input_activations, - const T* gating_output, - const WeightType* fc1_expert_weights, - const T* fc1_scales, - const T* fc1_expert_biases, - ActivationType fc1_activation_type, - const WeightType* fc2_expert_weights, - const T* fc2_scales, - const int num_rows, - const int hidden_size, - const int inter_size, - const int num_experts, - const int k, - char* workspace_ptr, - T* fc2_result, - T* expert_scales, - int* expanded_source_row_to_expanded_dest_row, - int* expert_for_source_row, - cudaStream_t stream); - - void run_moe_fc(const T* input_activations, - const T* gating_output, - const WeightType* fc1_expert_weights, - const T* fc1_scales, - const T* fc1_expert_biases, - ActivationType fc1_activation_type, - const WeightType* fc2_expert_weights, - const T* fc2_scales, - const int num_rows, - const int hidden_size, - const int inter_size, - const int num_experts, - const int k, - char* workspace_ptr, - T* fc2_result, - const bool* finished, - const int active_rows, - T* expert_scales, - int* expanded_source_row_to_expanded_dest_row, - int* expert_for_source_row, - cudaStream_t stream); - - void compute_total_rows_before_expert(const int* sorted_indices, - const int total_indices, - const int num_experts, - int64_t* total_rows_before_expert, - cudaStream_t stream); - -private: - void configure_ws_ptrs(char* ws_ptr, - const int num_rows, - const int hidden_size, - const int inter_size, - const int num_experts, - const int k); - -private: - CubKeyValueSorter sorter_; - MoeGemmRunner moe_gemm_runner_; - - // Pointers - int* source_rows_; - int* permuted_rows_; - int* permuted_experts_; - char* sorter_ws_; - T* permuted_data_; - T* softmax_out_; - - int64_t* total_rows_before_expert_; - - T* fc1_result_; + public: + CutlassMoeFCRunner(); + + size_t getWorkspaceSize( + const int num_rows, const int hidden_size, const int inter_size, const int num_experts, const int k); + + void run_moe_fc(const T* input_activations, + const T* gating_output, + const WeightType* fc1_expert_weights, + const T* fc1_scales, + const T* fc1_expert_biases, + ActivationType fc1_activation_type, + const WeightType* fc2_expert_weights, + const T* fc2_scales, + const int num_rows, + const int hidden_size, + const int inter_size, + const int num_experts, + const int k, + char* workspace_ptr, + T* fc2_result, + T* expert_scales, + int* expanded_source_row_to_expanded_dest_row, + int* expert_for_source_row, + cudaStream_t stream); + + void run_moe_fc(const T* input_activations, + const T* gating_output, + const WeightType* fc1_expert_weights, + const T* fc1_scales, + const T* fc1_expert_biases, + ActivationType fc1_activation_type, + const WeightType* fc2_expert_weights, + const T* fc2_scales, + const int num_rows, + const int hidden_size, + const int inter_size, + const int num_experts, + const int k, + char* workspace_ptr, + T* fc2_result, + const bool* finished, + const int active_rows, + T* expert_scales, + int* expanded_source_row_to_expanded_dest_row, + int* expert_for_source_row, + cudaStream_t stream); + + void compute_total_rows_before_expert(const int* sorted_indices, + const int total_indices, + const int num_experts, + int64_t* total_rows_before_expert, + cudaStream_t stream); + + private: + void configure_ws_ptrs(char* ws_ptr, + const int num_rows, + const int hidden_size, + 
const int inter_size, + const int num_experts, + const int k); + + private: + CubKeyValueSorter sorter_; + MoeGemmRunner moe_gemm_runner_; + + // Pointers + int* source_rows_; + int* permuted_rows_; + int* permuted_experts_; + char* sorter_ws_; + T* permuted_data_; + T* softmax_out_; + + int64_t* total_rows_before_expert_; + + T* fc1_result_; }; -template +template class CutlassMoeFCRunner::value>> { -public: - CutlassMoeFCRunner() = default; - - size_t getWorkspaceSize( - const int num_rows, const int hidden_size, const int inter_size, const int num_experts, const int k) - { - return 0; - } + public: + CutlassMoeFCRunner() = default; + + size_t getWorkspaceSize( + const int num_rows, const int hidden_size, const int inter_size, const int num_experts, const int k) { + return 0; + } }; } // namespace fastertransformer \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h index 88e6ebe743cc5..655a3b6d9041e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h @@ -46,293 +46,269 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// /// Visitor class to abstract away the algorithm for iterating over tiles -template +template struct BaseMoeProblemVisitor { - using ThreadblockShape = ThreadblockShape_; + using ThreadblockShape = ThreadblockShape_; - struct ProblemInfo { - static int32_t const kNoPrefetchEntry = -1; - int32_t problem_idx; - int32_t problem_start; + struct ProblemInfo { + static int32_t const kNoPrefetchEntry = -1; + int32_t problem_idx; + int32_t problem_start; - CUTLASS_DEVICE - ProblemInfo(): problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {} - - CUTLASS_DEVICE - ProblemInfo(int32_t problem_idx_, int32_t problem_start_): - problem_idx(problem_idx_), problem_start(problem_start_) - { - } - }; - - struct Params { - int64_t const* last_row_for_problem; - int64_t gemm_n; - int64_t gemm_k; - int32_t problem_count; - void const* workspace; - int32_t tile_count; - - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - Params(): - last_row_for_problem(nullptr), gemm_n(0), gemm_k(0), problem_count(0), workspace(nullptr), tile_count(0) - { - } + CUTLASS_DEVICE + ProblemInfo() : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {} - /// Ctor - CUTLASS_HOST_DEVICE - Params(int64_t const* last_row_for_problem, - int64_t gemm_n, - int64_t gemm_k, - int32_t problem_count, - void const* workspace = nullptr, - int32_t tile_count = 0): - last_row_for_problem(last_row_for_problem), - gemm_n(gemm_n), - gemm_k(gemm_k), - problem_count(problem_count), - workspace(workspace), - tile_count(tile_count) - { - } - }; + CUTLASS_DEVICE + ProblemInfo(int32_t problem_idx_, int32_t problem_start_) : problem_idx(problem_idx_), problem_start(problem_start_) { + } + }; - Params const& params; - int32_t tile_idx; - int32_t problem_tile_start; - int32_t problem_idx; + struct Params { + int64_t const* last_row_for_problem; + int64_t gemm_n; + int64_t gemm_k; + int32_t problem_count; + void const* workspace; + int32_t tile_count; // // Methods // - CUTLASS_DEVICE - BaseMoeProblemVisitor(Params const& params_, int32_t block_idx): - params(params_), tile_idx(block_idx), problem_tile_start(0), problem_idx(0) - { - } - - /// Get the grid shape - CUTLASS_HOST_DEVICE - static cutlass::gemm::GemmCoord grid_shape(const 
cutlass::gemm::GemmCoord& problem) - { - - return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), - ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), - 1); - } - - /// Gets the global tile index - CUTLASS_HOST_DEVICE - int32_t tile_index() const - { - return tile_idx; - } - - /// Gets the index of the problem - CUTLASS_HOST_DEVICE - int32_t problem_index() const - { - return problem_idx; - } - - CUTLASS_HOST_DEVICE - int32_t threadblock_idx() const - { - return tile_idx - problem_tile_start; - } - - CUTLASS_DEVICE - void advance(int32_t grid_size) - { - tile_idx += grid_size; - } - - CUTLASS_HOST_DEVICE - static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) - { - ProblemSizeHelper::possibly_transpose_problem(problem); - } - /// Returns the problem size for the current problem + /// Ctor CUTLASS_HOST_DEVICE - cutlass::gemm::GemmCoord problem_size() const - { - return problem_size(problem_idx); + Params() : last_row_for_problem(nullptr), gemm_n(0), gemm_k(0), problem_count(0), workspace(nullptr), tile_count(0) { } + /// Ctor CUTLASS_HOST_DEVICE - cutlass::gemm::GemmCoord problem_size(int idx) const - { - const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; - const int64_t current_problem_row = params.last_row_for_problem[idx]; - const int64_t gemm_m = current_problem_row - prev_problem_row; - GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); - ProblemSizeHelper::possibly_transpose_problem(problem); - return problem; + Params(int64_t const* last_row_for_problem, + int64_t gemm_n, + int64_t gemm_k, + int32_t problem_count, + void const* workspace = nullptr, + int32_t tile_count = 0) : last_row_for_problem(last_row_for_problem), + gemm_n(gemm_n), + gemm_k(gemm_k), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) { } - - CUTLASS_HOST_DEVICE - static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) - { - return ProblemSizeHelper::tile_count(grid); + }; + + Params const& params; + int32_t tile_idx; + int32_t problem_tile_start; + int32_t problem_idx; + + // + // Methods + // + CUTLASS_DEVICE + BaseMoeProblemVisitor(Params const& params_, int32_t block_idx) : params(params_), tile_idx(block_idx), problem_tile_start(0), problem_idx(0) { + } + + /// Get the grid shape + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), + 1); + } + + /// Gets the global tile index + CUTLASS_HOST_DEVICE + int32_t tile_index() const { + return tile_idx; + } + + /// Gets the index of the problem + CUTLASS_HOST_DEVICE + int32_t problem_index() const { + return problem_idx; + } + + CUTLASS_HOST_DEVICE + int32_t threadblock_idx() const { + return tile_idx - problem_tile_start; + } + + CUTLASS_DEVICE + void advance(int32_t grid_size) { + tile_idx += grid_size; + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { + ProblemSizeHelper::possibly_transpose_problem(problem); + } + + /// Returns the problem size for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size() const { + return problem_size(problem_idx); + } + + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size(int idx) const { + const int64_t 
prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; + const int64_t current_problem_row = params.last_row_for_problem[idx]; + const int64_t gemm_m = current_problem_row - prev_problem_row; + GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { + return ProblemSizeHelper::tile_count(grid); + } + + static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count) { + int32_t total_tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + auto problem = host_problem_sizes_ptr[i]; + possibly_transpose_problem(problem); + auto grid = grid_shape(problem); + total_tiles += tile_count(grid); } - static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count) - { - int32_t total_tiles = 0; - for (int32_t i = 0; i < problem_count; ++i) { - auto problem = host_problem_sizes_ptr[i]; - possibly_transpose_problem(problem); - auto grid = grid_shape(problem); - total_tiles += tile_count(grid); - } - - return total_tiles; - } + return total_tiles; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct MoeProblemVisitor; ///////////////////////////////////////////////////////////////////////////////////////////////// // ProblemVisitor that performs all scheduling on device // -template +template struct MoeProblemVisitor: public BaseMoeProblemVisitor { - using Base = BaseMoeProblemVisitor; - using Params = typename Base::Params; - static int const kThreadCount = ThreadCount; - static bool const kRequiresPrecomputation = false; - static int const kThreadsPerWarp = 32; - - struct SharedStorage {}; - - // Final tile of the problem loaded by this thread. Each thread will hold - // a separate value. - int32_t problem_ending_tile; - - SharedStorage& shared_storage; - - // - // Methods - // - CUTLASS_DEVICE - MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx): - Base(params_, block_idx), problem_ending_tile(0), shared_storage(shared_storage_) - { - this->problem_idx = -1 * kThreadsPerWarp; - this->problem_tile_start = 0; + ThreadCount> : public BaseMoeProblemVisitor { + using Base = BaseMoeProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static bool const kRequiresPrecomputation = false; + static int const kThreadsPerWarp = 32; + + struct SharedStorage {}; + + // Final tile of the problem loaded by this thread. Each thread will hold + // a separate value. + int32_t problem_ending_tile; + + SharedStorage& shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) : Base(params_, block_idx), problem_ending_tile(0), shared_storage(shared_storage_) { + this->problem_idx = -1 * kThreadsPerWarp; + this->problem_tile_start = 0; + } + + CUTLASS_DEVICE + bool next_tile() { + // Check whether the tile to compute is within the range of the current problem. 
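// (Overview) next_tile() is the device-side scheduler for the persistent-kernel
// loop: each warp cooperatively fetches kThreadsPerWarp problems at a time (one
// problem per lane), prefix-sums their tile counts to obtain per-problem ending
// tiles, and each threadblock then maps its global tile_idx onto the owning
// problem with a ballot, with no host-side precomputation
// (kRequiresPrecomputation is false above).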
 
/////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename ProblemSizeHelper, typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount>
+template <typename ProblemSizeHelper, typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount>
 struct MoeProblemVisitor;

/////////////////////////////////////////////////////////////////////////////////////////////////
// ProblemVisitor that performs all scheduling on device
//
-template<typename ProblemSizeHelper, typename ThreadblockShape, int PrefetchTileCount, int ThreadCount>
+template <typename ProblemSizeHelper, typename ThreadblockShape, int PrefetchTileCount, int ThreadCount>
 struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode::kDeviceOnly, PrefetchTileCount,
-    ThreadCount>: public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape> {
-    using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
-    using Params = typename Base::Params;
-    static int const kThreadCount = ThreadCount;
-    static bool const kRequiresPrecomputation = false;
-    static int const kThreadsPerWarp = 32;
-
-    struct SharedStorage {};
-
-    // Final tile of the problem loaded by this thread. Each thread will hold
-    // a separate value.
-    int32_t problem_ending_tile;
-
-    SharedStorage& shared_storage;
-
-    //
-    // Methods
-    //
-    CUTLASS_DEVICE
-    MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx):
-        Base(params_, block_idx), problem_ending_tile(0), shared_storage(shared_storage_)
-    {
-        this->problem_idx = -1 * kThreadsPerWarp;
-        this->problem_tile_start = 0;
+                         ThreadCount> : public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape> {
+  using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
+  using Params = typename Base::Params;
+  static int const kThreadCount = ThreadCount;
+  static bool const kRequiresPrecomputation = false;
+  static int const kThreadsPerWarp = 32;
+
+  struct SharedStorage {};
+
+  // Final tile of the problem loaded by this thread. Each thread will hold
+  // a separate value.
+  int32_t problem_ending_tile;
+
+  SharedStorage& shared_storage;
+
+  //
+  // Methods
+  //
+  CUTLASS_DEVICE
+  MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) : Base(params_, block_idx), problem_ending_tile(0), shared_storage(shared_storage_) {
+    this->problem_idx = -1 * kThreadsPerWarp;
+    this->problem_tile_start = 0;
+  }
+
+  CUTLASS_DEVICE
+  bool next_tile() {
+    // Check whether the tile to compute is within the range of the current problem.
+    int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
+    if (this->tile_idx < problem_tile_end) {
+      return true;
     }
-    CUTLASS_DEVICE
-    bool next_tile()
-    {
-        // Check whether the tile to compute is within the range of the current problem.
-        int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
-        if (this->tile_idx < problem_tile_end) {
-            return true;
+
+    // Check whether the tile to compute is within the current group of problems fetched by the warp.
+    // The last tile for this group is the final tile of the problem held by the final thread in the warp.
+    int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
+
+    // Keep the starting problem for this group in `problem_idx`. This is done to reduce
+    // register pressure. The starting problem for this group is simply the first problem
+    // in the group most recently fetched by the warp.
+    int32_t& group_problem_start = this->problem_idx;
+    group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;
+
+    // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce
+    // register pressure.
+    int32_t& group_tile_start = this->problem_tile_start;
+
+    // Each thread in the warp processes a separate problem to advance until
+    // reaching a problem whose starting tile is less than tile_idx.
+    while (group_tile_end <= this->tile_idx) {
+      group_problem_start += kThreadsPerWarp;
+      if (group_problem_start > this->params.problem_count) {
+        return false;
+      }
+
+      // Since `group_tile_start` is a reference to `this->problem_tile_start`, this
+      // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start`
+      // is also set here is used later in `next_tile`.
+      group_tile_start = group_tile_end;
+
+      int lane_idx = threadIdx.x % kThreadsPerWarp;
+      int32_t lane_problem = group_problem_start + lane_idx;
+
+      // Compute the number of tiles in the problem assigned to each thread.
+      problem_ending_tile = 0;
+      if (lane_problem < this->params.problem_count) {
+        cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
+        cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
+        problem_ending_tile = this->tile_count(grid);
+      }
+
+      // Compute a warp-wide inclusive prefix sum to compute the ending tile index of
+      // each thread's problem.
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 1; i < kThreadsPerWarp; i <<= 1) {
+        int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
+        if (lane_idx >= i) {
+          problem_ending_tile += val;
         }
+      }
-        // Check whether the tile to compute is within the current group of problems fetched by the warp.
-        // The last tile for this group is the final tile of the problem held by the final thread in the warp.
-        int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
-
-        // Keep the starting problem for this group in `problem_idx`. This is done to reduce
-        // register pressure. The starting problem for this group is simply the first problem
-        // in the group most recently fetched by the warp.
-        int32_t& group_problem_start = this->problem_idx;
-        group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;
-
-        // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce
-        // register pressure.
-        int32_t& group_tile_start = this->problem_tile_start;
-
-        // Each thread in the warp processes a separate problem to advance until
-        // reaching a problem whose starting tile is less less than tile_idx.
-        while (group_tile_end <= this->tile_idx) {
-            group_problem_start += kThreadsPerWarp;
-            if (group_problem_start > this->params.problem_count) {
-                return false;
-            }
-
-            // Since `group_tile_start` is a reference to `this->problem_tile_start`, this
-            // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start`
-            // is also set here is used later in `next_tile`.
-            group_tile_start = group_tile_end;
-
-            int lane_idx = threadIdx.x % kThreadsPerWarp;
-            int32_t lane_problem = group_problem_start + lane_idx;
-
-            // Compute the number of tiles in the problem assigned to each thread.
-            problem_ending_tile = 0;
-            if (lane_problem < this->params.problem_count) {
-                cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
-                cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
-                problem_ending_tile = this->tile_count(grid);
-            }
-
-            // Compute a warp-wide inclusive prefix sum to compute the ending tile index of
-            // each thread's problem.
-            CUTLASS_PRAGMA_UNROLL
-            for (int i = 1; i < kThreadsPerWarp; i <<= 1) {
-                int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
-                if (lane_idx >= i) {
-                    problem_ending_tile += val;
-                }
-            }
-
-            // The total tile count for this group is now in the final position of the prefix sum
-            int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
-
-            problem_ending_tile += group_tile_start;
-            group_tile_end += tiles_in_group;
-        }
+
+      // The total tile count for this group is now in the final position of the prefix sum
+      int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);
-        // The next problem to process is the first one that does not have ending tile position
-        // that is greater than or equal to tile index.
-        int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));
+
+      problem_ending_tile += group_tile_start;
+      group_tile_end += tiles_in_group;
+    }
-        this->problem_idx = group_problem_start + problem_idx_in_group;
+
+    // The next problem to process is the first one whose ending tile position is greater
+    // than the tile index.
+    int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));
-        // The starting tile for this problem is the ending tile of the previous problem. In cases
-        // where `problem_idx_in_group` is the first problem in the group, we do not need to reset
-        // `problem_tile_start`, because it is set to the previous group's ending tile in the while
-        // loop above.
-        if (problem_idx_in_group > 0) {
-            this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
-        }
+
+    this->problem_idx = group_problem_start + problem_idx_in_group;
-        return true;
+
+    // The starting tile for this problem is the ending tile of the previous problem. In cases
+    // where `problem_idx_in_group` is the first problem in the group, we do not need to reset
+    // `problem_tile_start`, because it is set to the previous group's ending tile in the while
+    // loop above.
+    if (problem_idx_in_group > 0) {
+      this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
     }
-    static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
-                                     int32_t problem_count,
-                                     int32_t block_count)
-    {
-        return 0;
-    }
+
+    return true;
+  }
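next_tile() leans on two warp-level idioms: an inclusive prefix sum built from __shfl_up_sync, which turns per-lane tile counts into ending tile indices, and a __ballot_sync/__popc pair that ranks the first problem whose ending tile exceeds the current tile index. A self-contained CUDA sketch of just those two steps (illustrative only; the kernel name and input values are made up):

#include <cstdint>
#include <cstdio>

__global__ void warp_scan_demo(const int32_t* tiles_per_problem, int32_t tile_idx) {
  const int lane = threadIdx.x % 32;
  int32_t ending_tile = tiles_per_problem[lane];

  // Inclusive prefix sum: after the loop, lane i holds tiles[0] + ... + tiles[i].
  for (int i = 1; i < 32; i <<= 1) {
    int32_t val = __shfl_up_sync(0xffffffff, ending_tile, i);
    if (lane >= i) ending_tile += val;
  }

  // Lanes whose problem ends at or before tile_idx vote 1; the popcount is the
  // index (within this group) of the problem that owns tile_idx.
  int32_t problem_in_group = __popc(__ballot_sync(0xffffffff, ending_tile <= tile_idx));
  if (lane == 0) printf("tile %d belongs to problem %d\n", tile_idx, problem_in_group);
}

int main() {
  int32_t h_tiles[32] = {3, 5, 2, 4};  // problems 4..31 contribute zero tiles
  int32_t* d_tiles;
  cudaMalloc(&d_tiles, sizeof(h_tiles));
  cudaMemcpy(d_tiles, h_tiles, sizeof(h_tiles), cudaMemcpyHostToDevice);
  warp_scan_demo<<<1, 32>>>(d_tiles, /*tile_idx=*/6);  // prefix sums 3,8,10,14 -> problem 1
  cudaDeviceSynchronize();
  cudaFree(d_tiles);
  return 0;
}
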
-    static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
-                                int32_t problem_count,
-                                int32_t block_count,
-                                void* host_workspace_ptr)
-    {
-    }
+
+  static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
+                                   int32_t problem_count,
+                                   int32_t block_count) {
+    return 0;
+  }
+
+  static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
+                              int32_t problem_count,
+                              int32_t block_count,
+                              void* host_workspace_ptr) {
+  }
 };

 }  // namespace kernel
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h
index bb0808522b19a..3505bea24e4d9 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h
@@ -41,20 +41,20 @@ namespace cutlass {
 namespace layout {
 
-template<int RowsPerTile, int ColumnsInterleaved>
+template <int RowsPerTile, int ColumnsInterleaved>
 class ColumnMajorTileInterleave {
-    static constexpr int kRowsPerTile = RowsPerTile;
-    static constexpr int kColumnsInterleaved = ColumnsInterleaved;
+  static constexpr int kRowsPerTile = RowsPerTile;
+  static constexpr int kColumnsInterleaved = ColumnsInterleaved;
 };
 
-template<typename Layout>
+template <typename Layout>
 struct IsColumnMajorTileInterleave {
-    static constexpr bool value = false;
+  static constexpr bool value = false;
 };
 
-template<int RowsPerTile, int ColumnsInterleaved>
+template <int RowsPerTile, int ColumnsInterleaved>
 struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>> {
-    static constexpr bool value = true;
+  static constexpr bool value = true;
 };
 
 }  // namespace layout
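The layout change above is a classic detection-trait pair: the primary IsColumnMajorTileInterleave template defaults to false, and a partial specialization matches any ColumnMajorTileInterleave instantiation. A self-contained C++ sketch (a simplified restatement for illustration, not the shipped header) showing how the trait evaluates:

#include <cstdio>

template <int RowsPerTile, int ColumnsInterleaved>
class ColumnMajorTileInterleave {};  // simplified stand-in for the real layout

template <typename Layout>
struct IsColumnMajorTileInterleave {
  static constexpr bool value = false;  // default: not an interleaved layout
};

template <int RowsPerTile, int ColumnsInterleaved>
struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>> {
  static constexpr bool value = true;  // matches any interleaved instantiation
};

struct RowMajor {};  // stand-in for any other layout

int main() {
  std::printf("interleaved: %d, row-major: %d\n",
              int(IsColumnMajorTileInterleave<ColumnMajorTileInterleave<4, 2>>::value),
              int(IsColumnMajorTileInterleave<RowMajor>::value));  // prints 1, 0
  return 0;
}
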
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc
index 4c7c65fe572f5..272019fb3f910 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -57,12 +57,12 @@ Status MoEBlock<T>::ComputeInternal(OpKernelContext* context) const {
   size_t expanded_source_row_to_expanded_dest_row_size = k_ * num_rows * sizeof(int);
   size_t expert_for_source_row_size = k_ * num_rows * sizeof(int);
 
-  //TODO: check shape
+  // TODO: check shape
   AllocatorPtr allocator;
   ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
 
-  //TODO: allocate once and reuse
+  // TODO: allocate once and reuse
   IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, ws_size, false, stream);
   IAllocatorUniquePtr<void> fc2_output = IAllocator::MakeUniquePtr<void>(allocator, fc2_output_size, false, stream);
   IAllocatorUniquePtr<void> expert_scales = IAllocator::MakeUniquePtr<void>(allocator, expert_scales_size, false, stream);
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h
index dfa5437413685..9480de73b223a 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.h
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.h
@@ -16,7 +16,7 @@ using namespace onnxruntime::cuda;
 template <typename T>
 class MoEBlock final : public CudaKernel {
  public:
-  explicit MoEBlock(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info){
+  explicit MoEBlock(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
     ORT_ENFORCE(op_kernel_info.GetAttr("k", &k_).IsOK());
 
     std::string activation_type_str;
diff --git a/onnxruntime/test/python/transformers/test_parity_moe_block.py b/onnxruntime/test/python/transformers/test_parity_moe_block.py
index 2e7e9b30e8c9b..a54b66c91d1fe 100644
--- a/onnxruntime/test/python/transformers/test_parity_moe_block.py
+++ b/onnxruntime/test/python/transformers/test_parity_moe_block.py
@@ -15,12 +15,12 @@
 import numpy
 import numpy as np
 import onnx
-import onnxruntime
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+import onnxruntime
+
 
 def create_moe_onnx_graph(
     num_rows,
@@ -129,6 +129,12 @@ def onnx_inference(
     from onnxruntime import InferenceSession, SessionOptions
 
     sess_options = SessionOptions()
+
+    cuda_providers = ["CUDAExecutionProvider"]
+    if cuda_providers[0] not in onnxruntime.get_available_providers():
+        return None
+
+    # TODO: move this to session creation
     sess_options.log_severity_level = 2
     ort_session = InferenceSession(onnx_model_path, sess_options, providers=["CUDAExecutionProvider"])
 
@@ -267,7 +273,7 @@ def torch_forward(self):
         x = x * scores
 
         x = x.reshape(B, T, C)
-        #print(x)
+        # print(x)
         return x, torch.sum(x)
 
     def onnx_forward(self):
@@ -282,7 +288,7 @@ def onnx_forward(self):
             "gated_output": numpy.ascontiguousarray(logits.detach().numpy().astype(numpy.float16)),
         }
         ort_output = onnx_inference(self.moe_onnx_graph, ort_inputs)
-        #print(ort_output)
+        # print(ort_output)
         return ort_output