From 6ff31e06d5757779b9c8d53e9d02a3b62b3e3438 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Tue, 19 Mar 2024 21:28:15 -0700 Subject: [PATCH] [MoE] Add TP and Mixtral MoE (#19945) ### Description 1.Support Tensor Parallelism in ShardedMoE. 2.Make necessary code changes to support Mixtral MoE. 3.Fix a bug related to using IOBinding in test script. 4.Fix the input size limitation ### Motivation and Context --- docs/ContribOperators.md | 16 +- docs/OperatorKernels.md | 2 +- .../cuda/collective/sharded_moe.cc | 113 ++++-- .../contrib_ops/cuda/collective/sharded_moe.h | 1 + .../cuda/moe/ft_moe/epilogue_helpers.h | 33 +- .../cuda/moe/ft_moe/moe_gemm_kernels.h | 9 +- .../moe/ft_moe/moe_gemm_kernels_template.h | 48 ++- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 304 +++++++++++---- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.h | 22 +- onnxruntime/contrib_ops/cuda/moe/moe.cc | 58 ++- onnxruntime/contrib_ops/cuda/moe/moe_base.h | 50 ++- .../core/graph/contrib_ops/collective_defs.cc | 32 +- .../core/graph/contrib_ops/contrib_defs.cc | 11 +- .../core/providers/cuda/cu_inc/common.cuh | 4 +- onnxruntime/test/contrib_ops/moe_test.cc | 177 ++++++++- .../sharded_moe/test_sharded_moe.py | 260 ++++++++++--- .../transformers/test_parity_mixtral_moe.py | 365 ++++++++++++++++++ .../python/transformers/test_parity_moe.py | 13 +- 18 files changed, 1272 insertions(+), 246 deletions(-) create mode 100644 onnxruntime/test/python/transformers/test_parity_mixtral_moe.py diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 5f0100fad95a2..32a4ca16b7824 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2931,8 +2931,8 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.MoE** Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, - GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) - usually uses top 32 experts. + GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) + usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral). #### Version @@ -2946,9 +2946,11 @@ This version of the operator has been available since version 1 of the 'com.micr
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
k : int
Number of top experts to select from expert pool
+
normalize_routing_weights : int
+
Whether to normalize routing weights
-#### Inputs (4 - 6) +#### Inputs (5 - 8)
input : T
@@ -2957,12 +2959,16 @@ This version of the operator has been available since version 1 of the 'com.micr
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T
3D input tensor with shape (num_experts, hidden_size, inter_size)
-
fc2_experts_weights : T
-
3D input tensor with shape (num_experts, inter_size, hidden_size)
fc1_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, inter_size)
+
fc2_experts_weights : T
+
3D input tensor with shape (num_experts, inter_size, hidden_size)
fc2_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, hidden_size)
+
fc3_experts_weights (optional) : T
+
3D optional input tensor with shape (num_experts, hidden_size, inter_size)
+
fc3_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, inter_size)
#### Outputs diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index eddc3b7873d80..bca8e17b3dfd4 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -861,7 +861,7 @@ Do not modify directly.* |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| -|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 40a667ffd5d83..2efc37cf98010 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" @@ -35,6 +37,7 @@ using namespace ONNX_NAMESPACE; template ShardedMoE::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("tensor_shards", &tensor_shards_).IsOK()); ORT_ENFORCE(op_kernel_info.GetAttr("local_experts_start_index", &local_experts_start_index_).IsOK()); rank_to_experts_start_index_.resize(nccl_->Size()); // Initialize rank_to_experts_start_index_[0] to a value to convey that it is not initialized. @@ -55,27 +58,36 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { // Create a {Rank, ExpertsStartIndex} map on Host. AutoDestoryCudaEvent cuda_event; cudaEvent_t& copy_event = cuda_event.Get(); - ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event)); const Tensor* input = context->Input(0); const Tensor* router_probs = context->Input(1); const Tensor* fc1_experts_weights = context->Input(2); - const Tensor* fc2_experts_weights = context->Input(3); - const Tensor* fc1_experts_bias_optional = context->Input(4); + const Tensor* fc1_experts_bias_optional = context->Input(3); + const Tensor* fc2_experts_weights = context->Input(4); const Tensor* fc2_experts_bias_optional = context->Input(5); + const Tensor* fc3_experts_weights_optional = context->Input(6); + const Tensor* fc3_experts_bias_optional = context->Input(7); + + MoEParameters moe_params(tensor_shards_); + ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional, + fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional, + fc3_experts_bias_optional)); - MoEParameters moe_params; - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights, - fc1_experts_bias_optional, fc2_experts_bias_optional)); ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm); + if (moe_params.parallel_type == MoEParallelType::EP || moe_params.parallel_type == MoEParallelType::EPAndTP) { + ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event)); + } + + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_); size_t ws_size = - moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), - static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), - static_cast(k_)); + moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), + static_cast(moe_params.num_experts), static_cast(k_)); size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT); size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT); @@ -93,19 +105,25 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { IAllocatorUniquePtr expert_for_source_row = IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); - // fc1_scales and fc2_scales are used in quantized MoE - const CudaT* fc1_scales_ptr = nullptr; - const CudaT* fc2_scales_ptr = nullptr; + const CudaT* fc_scales_ptr = nullptr; moe_runner.run_moe_fc(reinterpret_cast(input->template Data()), reinterpret_cast(router_probs->template Data()), reinterpret_cast(fc1_experts_weights->template Data()), - std::move(fc1_scales_ptr), + std::move(fc_scales_ptr), fc1_experts_bias_optional == nullptr ? nullptr : reinterpret_cast(fc1_experts_bias_optional->template Data()), - activation_type_, reinterpret_cast(fc2_experts_weights->template Data()), - std::move(fc2_scales_ptr), static_cast(moe_params.num_rows), + activation_type_, + fc3_experts_weights_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_weights_optional->template Data()), + std::move(fc_scales_ptr), + fc3_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_bias_optional->template Data()), + reinterpret_cast(fc2_experts_weights->template Data()), + std::move(fc_scales_ptr), static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), static_cast(moe_params.local_num_experts), static_cast(local_experts_start_index_), @@ -116,31 +134,54 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, input->Shape()); - size_t stride_count = moe_params.hidden_size; - size_t stride_bytes = stride_count * sizeof(CudaT); - int64_t total_past_rows = 0; - int64_t total_covered_rows = 0; - if (copy_event != nullptr) { - CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event)); + if (moe_params.parallel_type == MoEParallelType::None) { + fc2_output_bc = std::move(fc2_output); } - NCCL_RETURN_IF_ERROR(ncclGroupStart()); - for (int rank = 0; rank < nccl_->Size(); ++rank) { - int64_t experts_start_index = rank_to_experts_start_index_[rank]; - moe_runner.get_total_rows_info(experts_start_index, - moe_params.local_num_experts, - total_past_rows, - total_covered_rows); - const char* src = reinterpret_cast(fc2_output.get()) + total_past_rows * stride_bytes; - char* dst = reinterpret_cast(fc2_output_bc.get()) + total_past_rows * stride_bytes; - NCCL_RETURN_IF_ERROR(ncclBroadcast(src, - dst, - total_covered_rows * stride_count, + + if (moe_params.parallel_type == MoEParallelType::EPAndTP) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expert and Tensor Parallelism is not supported yet"); + } + + if (moe_params.parallel_type == MoEParallelType::TP) { + ORT_ENFORCE(moe_params.tensor_shards == nccl_->Size()); + NCCL_RETURN_IF_ERROR(ncclGroupStart()); + NCCL_RETURN_IF_ERROR(ncclAllReduce(reinterpret_cast(fc2_output.get()), + reinterpret_cast(fc2_output_bc.get()), + fc2_output_size / sizeof(CudaT), GetNcclDataType(input->DataType()), - rank, + ncclSum, nccl_->Comm(), Stream(context))); + NCCL_RETURN_IF_ERROR(ncclGroupEnd()); + } + + if (moe_params.parallel_type == MoEParallelType::EP) { + size_t stride_count = moe_params.hidden_size; + size_t stride_bytes = stride_count * sizeof(CudaT); + int64_t total_past_rows = 0; + int64_t total_covered_rows = 0; + if (copy_event != nullptr) { + CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event)); + } + NCCL_RETURN_IF_ERROR(ncclGroupStart()); + for (int rank = 0; rank < nccl_->Size(); ++rank) { + int64_t experts_start_index = rank_to_experts_start_index_[rank]; + moe_runner.get_total_rows_info(experts_start_index, + moe_params.local_num_experts, + total_past_rows, + total_covered_rows); + const char* src = reinterpret_cast(fc2_output.get()) + total_past_rows * stride_bytes; + char* dst = reinterpret_cast(fc2_output_bc.get()) + total_past_rows * stride_bytes; + NCCL_RETURN_IF_ERROR(ncclBroadcast(src, + dst, + total_covered_rows * stride_count, + GetNcclDataType(input->DataType()), + rank, + nccl_->Comm(), + Stream(context))); + } + NCCL_RETURN_IF_ERROR(ncclGroupEnd()); } - NCCL_RETURN_IF_ERROR(ncclGroupEnd()); ort_fastertransformer::finalize_moe_routing_kernelLauncher( reinterpret_cast(fc2_output_bc.get()), reinterpret_cast(output->template MutableData()), diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h index 5ea4ae59c4020..827283a794dd6 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h @@ -26,6 +26,7 @@ class ShardedMoE final : public NcclKernel, public MoEBase { Status SynchronizeExpertsStartIndex(AllocatorPtr& alloc, OpKernelContext* ctx, cudaEvent_t& cuda_event) const; int64_t local_experts_start_index_; + int64_t tensor_shards_; std::vector rank_to_experts_start_index_; }; diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h index 78d206bf1d9bc..b18a70e899d1c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h @@ -83,10 +83,16 @@ namespace ort_fastertransformer { struct EpilogueOpBiasSilu {}; +struct EpilogueOpNoBiasSilu {}; + struct EpilogueOpBiasReLU {}; +struct EpilogueOpNoBiasReLU {}; + struct EpilogueOpBiasFtGelu {}; +struct EpilogueOpNoBiasFtGelu {}; + struct EpilogueOpBias {}; struct EpilogueOpNoBias {}; @@ -101,6 +107,13 @@ struct Epilogue; }; +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationRelu; }; +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombinationGeneric< @@ -116,6 +136,14 @@ struct Epilogue; }; +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric< + cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator, + ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling, + cutlass::FloatRoundStyle::round_to_nearest, true>; +}; + template struct Epilogue { using Op = cutlass::epilogue::thread::LinearCombination struct Epilogue { using Op = - cutlass::epilogue::thread::LinearCombination; + cutlass::epilogue::thread::LinearCombination< + ElementType, ElementsPerVectorAccess, ElementAccumulator, + ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; }; } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h index 60608f462fde5..e0f91ab806c85 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -42,8 +42,13 @@ class MoeGemmRunner { int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream); - void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert, - int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream); + void moe_gemm_act(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert, + int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + ActivationType activation_type, cudaStream_t stream); + + void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, + int num_experts, cudaStream_t stream); private: template diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index a3dcf0da16b98..2a15fdfd1cc1a 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -311,8 +311,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig template ::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/, + int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, + int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: @@ -429,11 +429,47 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightTyp } template -void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, - int64_t gemm_k, int num_experts, cudaStream_t stream) { - run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, +void MoeGemmRunner::moe_gemm_act(const T* A, const WeightType* B, const T* weight_scales, + T* C, int64_t* total_rows_before_expert, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, + ActivationType activation_type, cudaStream_t stream) { + switch (activation_type) { + case ActivationType::Relu: + run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, stream); + break; + case ActivationType::Gelu: + run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, stream); + break; + case ActivationType::Silu: + run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, stream); + break; + case ActivationType::Identity: + run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); + break; + case ActivationType::InvalidType: + ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); + break; + default: { + ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); + } + } +} + +template +void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, + T* C, int64_t* total_rows_before_expert, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream) { + if (biases != nullptr) { + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, stream); + } else { + run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); + } } } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index a5b47bcddefbc..5e6e484567988 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -30,7 +30,6 @@ #include "cutlass/array.h" #include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" #ifdef __GNUC__ #pragma GCC diagnostic pop @@ -49,15 +48,14 @@ #endif namespace ort_fastertransformer { - static constexpr int WARP_SIZE = 32; // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. template -__launch_bounds__(TPB) __global__ - void moe_softmax(const T* input, const bool* finished, T* output, const int num_cols) { +__launch_bounds__(TPB) __global__ void moe_softmax(const T* input, const bool* finished, T* output, + const int num_cols) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -108,14 +106,15 @@ __launch_bounds__(TPB) __global__ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 template -__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, const int) { +__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, int, bool) { // Does not support pre-Kepler architectures ; } #else template __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, const bool* finished, T* output, - int* indices, int* source_rows, int num_experts, int k) { + int* indices, int* source_rows, int num_experts, int k, + bool normalize_routing_weights) { using cub_kvp = cub::KeyValuePair; using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -128,6 +127,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, const bool should_process_row = finished ? !finished[block_row] : true; const int thread_read_offset = blockIdx.x * num_experts; + float output_row_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { thread_kvp.key = 0; thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities @@ -155,6 +155,13 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, output[idx] = result_kvp.value; indices[idx] = should_process_row ? result_kvp.key : num_experts; source_rows[idx] = k_idx * num_rows + block_row; + + if (normalize_routing_weights && k_idx == k - 1) { +#pragma unroll + for (int ki = 0; ki < k; ++ki) { + output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); + } + } } __syncthreads(); } @@ -178,7 +185,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topk_gating_softmax(const T* input, const bool* finished, T* output, int num_rows, int* indices, - int* source_rows, int k) { + int* source_rows, int k, bool normalize_routing_weights) { // We begin by enforcing compile time assertions and setting up compile time constants. static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); @@ -296,6 +303,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ int start_col = first_elt_read_by_thread; static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + float output_row_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { // First, each thread does the local argmax float max_val = row_chunk[0]; @@ -336,8 +344,16 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ // single) thread per row of the input/output matrices. const int idx = k * thread_row + k_idx; output[idx] = T(max_val); + output_row_sum = output_row_sum + static_cast(max_val); indices[idx] = should_process_row ? expert : NUM_EXPERTS; source_rows[idx] = k_idx * num_rows + thread_row; + + if (normalize_routing_weights && k_idx == k - 1) { +#pragma unroll + for (int ki = 0; ki < k; ++ki) { + output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum); + } + } } // Finally, we clear the value in the thread with the current max if there is another iteration to run. @@ -370,7 +386,8 @@ struct TopkConstants { template void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row, - int num_rows, int /*num_experts*/, int k, cudaStream_t stream) { + int num_rows, int /*num_experts*/, int k, bool normalize_routing_weights, + cudaStream_t stream) { static constexpr unsigned long MAX_BYTES_PER_LDG = 16; static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); @@ -382,61 +399,63 @@ void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T dim3 block_dim(WARP_SIZE, WARPS_PER_TB); topk_gating_softmax - <<>>(input, finished, output, num_rows, indices, source_row, k); + <<>>(input, finished, output, num_rows, indices, source_row, k, + normalize_routing_weights); } template void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_output, int* indices, int* source_row, int num_rows, int num_experts, - int k, cudaStream_t stream) { + int k, bool normalize_routing_weights, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; switch (num_experts) { case 2: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 4: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 8: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 16: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 32: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 64: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 128: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } case 256: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, - num_experts, k, stream); + num_experts, k, normalize_routing_weights, stream); break; } default: { static constexpr int TPB = 256; moe_softmax<<>>(input, finished, softmax_temp_output, num_experts); moe_top_k - <<>>(softmax_temp_output, finished, output, indices, source_row, num_experts, k); + <<>>(softmax_temp_output, finished, output, indices, source_row, num_experts, k, + normalize_routing_weights); } } } @@ -521,25 +540,31 @@ __global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, i } template -CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version) { - total_past_rows_ = 0; - total_covered_rows_ = 0; +CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, + bool has_fc3, + bool normalize_routing_weights) + : has_fc3_(has_fc3), + total_past_rows_(0), + total_covered_rows_(0), + normalize_routing_weights_(normalize_routing_weights) { moe_gemm_runner_.initialize(sm_version); } template -size_t CutlassMoeFCRunner::getWorkspaceSize(int num_rows, const int hidden_size, - const int inter_size, int num_experts, - int k) { - const int buf_size = static_cast(pad_to_multiple_of_16(k * num_rows * hidden_size)); - const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size)); - const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts)); - const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows)); - int num_softmax_outs = 0; +size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_rows, const size_t hidden_size, + const size_t inter_size, size_t num_experts, + size_t k) { + total_covered_rows_ = k * num_rows; + + const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size); + const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size); + const size_t padded_experts = pad_to_multiple_of_16(num_experts); + const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); + size_t num_softmax_outs = 0; const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) { - num_softmax_outs = static_cast(pad_to_multiple_of_16(num_rows * num_experts)); + num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts); } // softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them @@ -548,13 +573,13 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(int num_rows, total_ws_bytes += buf_size * sizeof(T); // permuted_data total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ total_ws_bytes += num_softmax_outs * sizeof(T); - const int bytes_for_fc1_result = interbuf_size * sizeof(T); - const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows))); - sorter_.update_num_experts(num_experts); + const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); + const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)); + sorter_.update_num_experts(static_cast(num_experts)); - int bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; if (sorter_ws_size_bytes > bytes_for_fc1_result) { - int remaining_bytes = static_cast(pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result)); + size_t remaining_bytes = pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result); bytes_for_intermediate_and_sorting += remaining_bytes; } @@ -563,13 +588,13 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(int num_rows, } template -void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, int num_rows, - const int hidden_size, const int inter_size, - int num_experts, int k) { - const int buf_size = static_cast(pad_to_multiple_of_16(k * num_rows * hidden_size)); - const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size)); - const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts)); - const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows)); +void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, size_t num_rows, + const size_t hidden_size, const size_t inter_size, + size_t num_experts, size_t k) { + const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size); + const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size); + const size_t padded_experts = pad_to_multiple_of_16(num_experts); + const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); source_rows_ = (int*)ws_ptr; permuted_rows_ = source_rows_ + num_moe_inputs; @@ -578,28 +603,130 @@ void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, total_rows_before_expert_ = (int64_t*)(permuted_data_ + buf_size); - fc1_result_ = (T*)(total_rows_before_expert_ + padded_experts); + if (has_fc3_) { + fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size); + } else { + fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + } const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) { - softmax_out_ = (T*)(fc1_result_ + interbuf_size); + softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); } else { softmax_out_ = nullptr; } } +namespace { + +struct __align__(8) Half4 { + half2 x; + half2 y; +}; + +// TODO(wy): move to common header +template +struct T4; +template <> +struct T4 { + using Type = float4; +}; +template <> +struct T4 { + using Type = Half4; +}; + +template +struct T2; +template <> +struct T2 { + using Type = float2; +}; +template <> +struct T2 { + using Type = half2; +}; + +inline __device__ float2 operator*(const float2 a, const float2 b) { + return make_float2(a.x * b.x, a.y * b.y); +} + +inline __device__ float4 operator*(const float4 a, const float4 b) { + return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 +inline __device__ half operator*(const half a, const half b) { + return __float2half(__half2float(a) * __half2float(b)); +} + +inline __device__ half2 operator*(const half2 a, const half2 b) { + return make_half2(a.x * b.x, a.y * b.y); +} +#endif + +inline __device__ Half4 operator*(const Half4 a, const Half4 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 + Half4 result; + result.x = a.x * b.x; + result.y = a.y * b.y; + return result; +#else + return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; +#endif +} + +} // anonymous namespace + +template +__global__ void elementWiseMulKernel(T* output, T const* input, size_t inter_size) { + int const tid = threadIdx.x; + int const token = blockIdx.x; + + output = output + token * inter_size; + input = input + token * inter_size; + for (int i = tid; i < inter_size; i += blockDim.x) { + T fc1_value = input[i]; + output[i] = fc1_value * output[i]; + } +} + +template +void elementWiseMul(T* output, T const* input, int inter_size, int num_tokens, cudaStream_t stream) { + int const blocks = num_tokens; + + if (inter_size & 3 == 0) { + using vec_type = typename T4::Type; + int const threads = std::min(inter_size / 4, 1024); + elementWiseMulKernel<<>>(reinterpret_cast(output), + reinterpret_cast(input), + inter_size / 4); + } else if (inter_size & 1 == 0) { + using vec_type = typename T2::Type; + int const threads = std::min(inter_size / 2, 1024); + elementWiseMulKernel<<>>(reinterpret_cast(output), + reinterpret_cast(input), + inter_size / 2); + } else { + int const threads = std::min(inter_size, 1024); + elementWiseMulKernel<<>>(output, input, inter_size); + } +} + template void CutlassMoeFCRunner::run_moe_fc( const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, - const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, - const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, - int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, - const bool* finished, int active_rows, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, - int* expert_for_source_row, cudaStream_t stream) { + const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc3_expert_weights, + const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, + int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts, + int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, + T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, + cudaStream_t stream) { static constexpr bool scales_required = std::is_same::value || std::is_same::value; - if constexpr (scales_required) { + if (scales_required) { if (fc1_scales == nullptr) { ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for first matmul is a null pointer"); } else if (fc2_scales == nullptr) { @@ -613,9 +740,10 @@ void CutlassMoeFCRunner::run_moe_fc( } } - configure_ws_ptrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts, k); + configure_ws_ptrs(workspace_ptr, static_cast(num_rows), static_cast(hidden_size), + static_cast(inter_size), static_cast(num_experts), static_cast(k)); topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, - source_rows_, num_rows, num_experts, k, stream); + source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream); const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows))); sorter_.run((void*)fc1_result_, sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_, @@ -634,15 +762,48 @@ void CutlassMoeFCRunner::run_moe_fc( } // expanded_active_expert_rows is not used - moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size, - fc1_expert_weights, fc1_scales, fc1_expert_biases, - fc1_result_ + total_past_rows_ * inter_size, - total_rows_before_expert_ + local_experts_start_index, - expanded_active_expert_rows, inter_size, hidden_size, - local_num_experts, fc1_activation_type, stream); + if (fc1_expert_biases != nullptr) { + moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size, + fc1_expert_weights, fc1_scales, fc1_expert_biases, + fc1_result_ + total_past_rows_ * inter_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, inter_size, hidden_size, + local_num_experts, fc1_activation_type, stream); + } else { + moe_gemm_runner_.moe_gemm_act(permuted_data_ + total_past_rows_ * hidden_size, + fc1_expert_weights, fc1_scales, + fc1_result_ + total_past_rows_ * inter_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, inter_size, hidden_size, + local_num_experts, fc1_activation_type, stream); + } + + if (has_fc3_) { + if (scales_required) { + if (fc3_scales == nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for third matmul is a null pointer"); + } + } else { + if (fc3_scales != nullptr) { + ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC3"); + } + } + if (fc3_expert_weights == nullptr) { + ORT_THROW("[FT Error][Run MoE FC] FC3 weights are null"); + } + moe_gemm_runner_.moe_gemm(permuted_data_ + total_past_rows_ * hidden_size, + fc3_expert_weights, fc3_scales, fc3_expert_biases, + fc3_result_ + total_past_rows_ * inter_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, inter_size, hidden_size, + local_num_experts, stream); + + elementWiseMul(fc1_result_ + total_past_rows_ * inter_size, fc3_result_ + total_past_rows_ * inter_size, + static_cast(inter_size), static_cast(total_covered_rows_), stream); + } moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size, - fc2_expert_weights, fc2_scales, + fc2_expert_weights, fc2_scales, nullptr, fc2_result + total_past_rows_ * hidden_size, total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, hidden_size, inter_size, local_num_experts, stream); @@ -651,14 +812,16 @@ void CutlassMoeFCRunner::run_moe_fc( template void CutlassMoeFCRunner::run_moe_fc( const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, - const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights, - const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, - int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales, + const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc3_expert_weights, + const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, + int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts, + int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type, - fc2_expert_weights, fc2_scales, num_rows, hidden_size, inter_size, num_experts, local_num_experts, - local_experts_start_index, k, workspace_ptr, fc2_result, nullptr, num_rows, expert_scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, stream); + fc3_expert_weights, fc3_scales, fc3_expert_biases, fc2_expert_weights, fc2_scales, num_rows, hidden_size, + inter_size, num_experts, local_num_experts, local_experts_start_index, k, workspace_ptr, fc2_result, + nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, + stream); } template @@ -811,9 +974,10 @@ __global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows, T* const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; const int expert_idx = expert_for_source_row[k_offset]; - const T* bias_ptr = bias + expert_idx * cols; + const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; - thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]); + thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + + (bias_ptr ? bias_ptr[tid] : T(0))); } reduced_row_ptr[tid] = thread_output; } @@ -866,9 +1030,9 @@ void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* red // ========================= TopK Softmax specializations =========================== template void topk_gating_softmax_kernelLauncher(const float*, const bool*, float*, float*, int*, int*, int, - int, int, cudaStream_t); + int, int, bool, cudaStream_t); template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, - int, int, cudaStream_t); + int, int, bool, cudaStream_t); // ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index 5cc2a3f79f003..5eef6f95f4820 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -24,6 +24,8 @@ #include "core/common/common.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" +#include "cutlass/numeric_types.h" + using namespace onnxruntime; namespace ort_fastertransformer { @@ -107,12 +109,13 @@ template class CutlassMoeFCRunner { public: - CutlassMoeFCRunner(int sm_version); + CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights); - size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k); + size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k); void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, + const WeightType* fc3_expert_weights, const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size, int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row, @@ -120,6 +123,7 @@ class CutlassMoeFCRunner { void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type, + const WeightType* fc3_expert_weights, const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size, int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, T* expert_scales, @@ -135,7 +139,8 @@ class CutlassMoeFCRunner { int64_t& total_covered_rows); private: - void configure_ws_ptrs(char* ws_ptr, int num_rows, int hidden_size, int inter_size, int num_experts, int k); + void configure_ws_ptrs(char* ws_ptr, size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, + size_t k); private: CubKeyValueSorter sorter_; @@ -152,12 +157,17 @@ class CutlassMoeFCRunner { int64_t* total_rows_before_expert_; T* fc1_result_; + T* fc3_result_; + + bool has_fc3_; + bool normalize_routing_weights_; // Cuda events contrib::cuda::AutoDestoryCudaEvent cuda_event_; int64_t total_past_rows_; int64_t total_covered_rows_; + // TODO: use pinned memory std::vector total_rows_before_expert_host_; }; @@ -165,11 +175,11 @@ class CutlassMoeFCRunner { template class CutlassMoeFCRunner::value>> { public: - CutlassMoeFCRunner(int sm_version); + CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights); - size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k) { + size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) { return 0; } }; -} // namespace ort_fastertransformer \ No newline at end of file +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index 3f26a274109ad..b13aab959fc48 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -39,13 +39,16 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* router_probs = context->Input(1); const Tensor* fc1_experts_weights = context->Input(2); - const Tensor* fc2_experts_weights = context->Input(3); - const Tensor* fc1_experts_bias_optional = context->Input(4); + const Tensor* fc1_experts_bias_optional = context->Input(3); + const Tensor* fc2_experts_weights = context->Input(4); const Tensor* fc2_experts_bias_optional = context->Input(5); + const Tensor* fc3_experts_weights_optional = context->Input(6); + const Tensor* fc3_experts_bias_optional = context->Input(7); MoEParameters moe_params; - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights, - fc1_experts_bias_optional, fc2_experts_bias_optional)); + ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional, + fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional, + fc3_experts_bias_optional)); typedef typename ToCudaType::MappedType CudaT; auto stream = context->GetComputeStream(); @@ -53,12 +56,14 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); const int sm = device_prop.major * 10 + device_prop.minor; - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_); size_t ws_size = - moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), - static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), - static_cast(k_)); + moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), + static_cast(moe_params.num_experts), static_cast(k_)); size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT); size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT); size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); @@ -77,26 +82,37 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { IAllocatorUniquePtr expert_for_source_row = IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); - // fc1_scales and fc2_scales are used in quantized MoE - const CudaT* fc1_scales_ptr = nullptr; - const CudaT* fc2_scales_ptr = nullptr; - + const CudaT* fc_scales_ptr = nullptr; moe_runner.run_moe_fc(reinterpret_cast(input->template Data()), reinterpret_cast(router_probs->template Data()), - reinterpret_cast(fc1_experts_weights->template Data()), - std::move(fc1_scales_ptr), + reinterpret_cast(fc1_experts_weights->DataRaw()), + fc_scales_ptr, fc1_experts_bias_optional == nullptr ? nullptr : reinterpret_cast(fc1_experts_bias_optional->template Data()), - activation_type_, reinterpret_cast(fc2_experts_weights->template Data()), - std::move(fc2_scales_ptr), static_cast(moe_params.num_rows), - static_cast(moe_params.hidden_size), static_cast(moe_params.inter_size), - static_cast(moe_params.num_experts), static_cast(moe_params.local_num_experts), - 0 /*local_experts_start_index_ used in sharded MoE*/, static_cast(k_), - reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()), + activation_type_, + fc3_experts_weights_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_weights_optional->DataRaw()), + fc_scales_ptr, + fc3_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_bias_optional->template Data()), + reinterpret_cast(fc2_experts_weights->DataRaw()), + fc_scales_ptr, + static_cast(moe_params.num_rows), + static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), + static_cast(moe_params.num_experts), + static_cast(moe_params.local_num_experts), + 0 /*local_experts_start_index_ used in sharded MoE*/, + static_cast(k_), + reinterpret_cast(work_space.get()), + reinterpret_cast(fc2_output.get()), reinterpret_cast(expert_scales.get()), reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), - reinterpret_cast(expert_for_source_row.get()), Stream(context)); + reinterpret_cast(expert_for_source_row.get()), + Stream(context)); Tensor* output = context->Output(0, input->Shape()); diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index f55a7cde2e208..84a5e8c7c120d 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -13,16 +13,22 @@ namespace cuda { enum class MoEParallelType { None = 0, - ExpertSlicing = 1, + EP = 1, + TP = 2, + EPAndTP = 3, }; struct MoEParameters { + MoEParameters() {} + explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {} int64_t num_rows; int64_t num_experts; int64_t local_num_experts; int64_t hidden_size; int64_t inter_size; + MoEParallelType parallel_type; + int64_t tensor_shards{1}; }; class MoEBase { @@ -31,9 +37,11 @@ class MoEBase { const Tensor* input, const Tensor* router_probs, const Tensor* fc1_experts_weights, - const Tensor* fc2_experts_weights, const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_bias_optional) const { + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias_optional, + const Tensor* fc3_experts_weights_optional, + const Tensor* fc3_experts_bias_optional) const { const auto& input_dims = input->Shape().GetDims(); const auto& router_probs_dims = router_probs->Shape().GetDims(); const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); @@ -83,12 +91,6 @@ class MoEBase { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ", router_probs_dims[0], " and ", num_rows); } - if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is set but fc2_experts_bias is not set"); - } - if (fc1_experts_bias_optional == nullptr && fc2_experts_bias_optional != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is not set but fc2_experts_bias is set"); - } if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) { const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims(); const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims(); @@ -126,15 +128,38 @@ class MoEBase { } } + if (fc3_experts_weights_optional != nullptr && + fc3_experts_weights_optional->Shape().GetDims() != fc1_experts_weights_dims) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc3_experts_weights_dims must be equal to fc1_experts_weights_dims, got ", + fc3_experts_weights_optional->Shape().GetDims(), " and ", fc1_experts_weights_dims); + } + + if (fc3_experts_bias_optional != nullptr && fc1_experts_bias_optional != nullptr && + fc3_experts_bias_optional->Shape().GetDims() != fc1_experts_bias_optional->Shape().GetDims()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc3_experts_bias_dims must be equal to fc1_experts_bias_dims, got ", + fc3_experts_bias_optional->Shape().GetDims(), " and ", + fc1_experts_bias_optional->Shape().GetDims()); + } + parameters.num_rows = num_rows; parameters.num_experts = num_experts; parameters.local_num_experts = local_num_experts; parameters.hidden_size = hidden_size; parameters.inter_size = inter_size; if (num_experts == local_num_experts) { - parameters.parallel_type = MoEParallelType::None; + if (parameters.tensor_shards == 1) { + parameters.parallel_type = MoEParallelType::None; + } else { + parameters.parallel_type = MoEParallelType::TP; + } } else if (num_experts > local_num_experts) { - parameters.parallel_type = MoEParallelType::ExpertSlicing; + if (parameters.tensor_shards == 1) { + parameters.parallel_type = MoEParallelType::EP; + } else { + parameters.parallel_type = MoEParallelType::EPAndTP; + } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_experts must be greater than or equal to local_num_experts, got ", @@ -161,8 +186,11 @@ class MoEBase { } else { ORT_THROW("Unsupported MoE activation type: ", activation_type_str); } + + normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault("normalize_routing_weights", 0) == 1; } + bool normalize_routing_weights_; int64_t k_; ort_fastertransformer::ActivationType activation_type_; }; diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 4aa43f5de1cd5..a0ca2e45f153a 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -91,10 +91,18 @@ void RegisterCollectiveOps() { "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) + .Attr("normalize_routing_weights", + "Whether to normalize routing weights", + AttributeProto::INT, + static_cast(0)) .Attr("local_experts_start_index", "The start index of local experts", AttributeProto::INT, - static_cast(-1)) + static_cast(0)) + .Attr("tensor_shards", + "Tensor parallelism config. The number of shards for each expert weight and bias", + AttributeProto::INT, + static_cast(1)) .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or " @@ -106,22 +114,32 @@ void RegisterCollectiveOps() { "T") .Input(2, "fc1_experts_weights", - "3D input tensor with shape (local_num_experts, hidden_size, inter_size)", + "3D input tensor with shape (local_num_experts, hidden_size, local_inter_size)", "T") .Input(3, - "fc2_experts_weights", - "3D input tensor with shape (local_num_experts, inter_size, hidden_size)", - "T") - .Input(4, "fc1_experts_bias", - "2D optional input tensor with shape (local_num_experts, inter_size)", + "2D optional input tensor with shape (local_num_experts, local_inter_size)", "T", OpSchema::Optional) + .Input(4, + "fc2_experts_weights", + "3D input tensor with shape (local_num_experts, local_inter_size, hidden_size)", + "T") .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) + .Input(6, + "fc3_experts_weights", + "3D optional input tensor with shape (local_num_experts, hidden_size, local_inter_size)", + "T", + OpSchema::Optional) + .Input(7, + "fc3_experts_bias", + "2D optional input tensor with shape (local_num_experts, local_inter_size)", + "T", + OpSchema::Optional) .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or " diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 6709398c788f0..82cc16acad582 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1382,8 +1382,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, constexpr const char* MoE_ver1_doc = R"DOC( Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, - GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) - usually uses top 32 experts. + GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) + usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral). )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, @@ -1391,12 +1391,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, .SetDoc(MoE_ver1_doc) .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) + .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") - .Input(3, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T") - .Input(4, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T") .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) + .Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, hidden_size, inter_size)", "T", OpSchema::Optional) + .Input(7, "fc3_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index bed2f677166d6..1cd3532846114 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -543,7 +543,7 @@ struct _IsNan { template <> struct _IsNan { __device__ __inline__ bool operator()(half a) const { - return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) + return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) > MLFloat16::kPositiveInfinityBits; } }; @@ -551,7 +551,7 @@ struct _IsNan { template <> struct _IsNan { __device__ __inline__ bool operator()(BFloat16 a) const { - return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) + return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) > BFloat16::kPositiveInfinityBits; } }; diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index e88ef7794cd07..263ace25ddfe0 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -14,6 +14,7 @@ static void RunMoETest( const std::vector& router_probs, const std::vector& fc1_experts_weights, const std::vector& fc2_experts_weights, + const std::vector& fc3_experts_weights, const std::vector& fc1_experts_bias, const std::vector& fc2_experts_bias, const std::vector& output_data, @@ -22,19 +23,23 @@ static void RunMoETest( int hidden_size, int inter_size, std::string activation_type, + int normalize_routing_weights = 0, + int top_k = 1, bool use_float16 = false) { int min_cuda_architecture = use_float16 ? 530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); if (enable_cuda) { OpTester tester("MoE", 1, onnxruntime::kMSDomain); - tester.AddAttribute("k", static_cast(1)); + tester.AddAttribute("k", static_cast(top_k)); tester.AddAttribute("activation_type", activation_type); + tester.AddAttribute("normalize_routing_weights", static_cast(normalize_routing_weights)); std::vector input_dims = {num_rows, hidden_size}; std::vector router_probs_dims = {num_rows, num_experts}; std::vector fc1_experts_weights_dims = {num_experts, hidden_size, inter_size}; std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc3_experts_weights_dims = fc1_experts_weights_dims; std::vector fc1_experts_bias_dims = {num_experts, inter_size}; std::vector fc2_experts_bias_dims = {num_experts, hidden_size}; std::vector output_dims = {num_rows, hidden_size}; @@ -43,18 +48,40 @@ static void RunMoETest( tester.AddInput("input", input_dims, ToFloat16(input)); tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, ToFloat16(fc1_experts_weights)); + if (!fc1_experts_bias.empty()) { + tester.AddInput("fc1_experts_bias", fc1_experts_bias_dims, ToFloat16(fc1_experts_bias)); + } else { + tester.AddOptionalInputEdge(); + } tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, ToFloat16(fc2_experts_weights)); - tester.AddInput("fc1_experts_bias", fc1_experts_bias_dims, ToFloat16(fc1_experts_bias)); - tester.AddInput("fc2_experts_bias", fc2_experts_bias_dims, ToFloat16(fc2_experts_bias)); + if (!fc2_experts_bias.empty()) { + tester.AddInput("fc2_experts_bias", fc2_experts_bias_dims, ToFloat16(fc2_experts_bias)); + } else { + tester.AddOptionalInputEdge(); + } + if (!fc3_experts_weights.empty()) { + tester.AddInput("fc3_experts_weights", fc3_experts_weights_dims, ToFloat16(fc3_experts_weights)); + } tester.AddOutput("output", output_dims, ToFloat16(output_data)); tester.SetOutputTolerance(0.005f); } else { tester.AddInput("input", input_dims, input); tester.AddInput("router_probs", router_probs_dims, router_probs); tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + if (!fc1_experts_bias.empty()) { + tester.AddInput("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias); + } else { + tester.AddOptionalInputEdge(); + } tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); - tester.AddInput("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias); - tester.AddInput("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias); + if (!fc2_experts_bias.empty()) { + tester.AddInput("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias); + } else { + tester.AddOptionalInputEdge(); + } + if (!fc3_experts_weights.empty()) { + tester.AddInput("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights); + } tester.AddOutput("output", output_dims, output_data); tester.SetOutputTolerance(0.001f); } @@ -233,6 +260,7 @@ TEST(MoETest, MoETest_Gelu) { router_probs, fc1_experts_weights, fc2_experts_weights, + {}, fc1_experts_bias, fc2_experts_bias, output, @@ -411,6 +439,7 @@ TEST(MoETest, MoETest_Relu) { router_probs, fc1_experts_weights, fc2_experts_weights, + {}, fc1_experts_bias, fc2_experts_bias, output, @@ -421,5 +450,143 @@ TEST(MoETest, MoETest_Relu) { "relu"); } +TEST(MoETest, MoETest_Mixtral) { + int num_rows = 6; + int num_experts = 8; + int hidden_size = 4; + int inter_size = 8; + + const std::vector input = { + 0.9212995f, 0.5282444f, -0.008228387f, -1.449332f, -0.6051824f, -0.17924511f, 0.1995587f, -1.2461947f, + 0.86708033f, 0.19191018f, 1.1600108f, -0.008815222f, 0.8504777f, -0.84964496f, -1.4019964f, 0.17225051f, + 0.35569248f, 1.2056456f, 1.3690308f, -0.69495815f, 1.4324434f, 0.22761835f, -1.1286871f, 1.124213f}; + const std::vector router_probs = { + -0.09331456f, -0.47121337f, 0.07311103f, 0.47643483f, 0.21135253f, -0.72226393f, -0.048502743f, 0.39447474f, + -0.9014899f, -0.36629856f, -0.23088816f, -0.099606544f, -0.45191774f, -0.30394578f, 0.6266495f, 0.67937183f, + 0.27117345f, -0.36059442f, 0.81510246f, 0.61359257f, 0.07649982f, -0.44949868f, -0.54758865f, 0.4736983f, + 0.21584567f, 0.21296778f, 0.093342215f, -0.09353682f, 0.61422515f, 0.19574627f, 0.0063361377f, -0.2465148f, + 0.15675665f, -0.4546509f, 0.24447554f, 0.5921611f, -0.18192923f, -0.66116416f, -0.40265432f, 0.33475468f, + 1.2906091f, 0.4709078f, 0.16256471f, 0.19308007f, 0.97568524f, 0.25876164f, -0.7964541f, -1.0319631f}; + const std::vector fc1_experts_weights = { + 0.3860137f, 0.077925384f, 0.13434184f, 0.28902978f, 0.25391752f, -0.38351142f, 0.15813059f, 0.031481862f, + 0.083209574f, 0.4039817f, -0.13558972f, -0.21858627f, -0.30475253f, 0.41026944f, -0.008697987f, -0.3412701f, + -0.16235226f, 0.054659843f, 0.21042877f, 0.28863233f, -0.49495423f, 0.14401567f, 0.39130414f, 0.154176f, + 0.30897498f, -0.15768659f, 0.44641107f, 0.089463115f, -0.19318026f, 0.20710677f, -0.3552568f, -0.17219114f, + 0.41923493f, -0.4233985f, -0.41503525f, 0.19466156f, -0.08633667f, 0.45547962f, -0.054792404f, 0.26722562f, + -0.09923202f, 0.3460176f, -0.49708033f, -0.41033173f, 0.10443485f, -0.39646107f, -0.37424505f, 0.1757198f, + 0.43019837f, -0.13757241f, 0.14305532f, 0.37121457f, 0.2581259f, 0.12583363f, 0.45542932f, 0.16247797f, + 0.15579104f, -0.19166303f, -0.109221935f, -0.36702687f, 0.40365517f, -0.21506298f, -0.36697525f, -0.2703231f, + -0.49740213f, -0.3486371f, 0.24005288f, -0.0048963428f, 0.20468098f, -0.09111178f, -0.1485982f, -0.088219464f, + 0.33463532f, -0.49346995f, 0.42075223f, -0.38025302f, -0.245484f, -0.35191745f, 0.3086716f, -0.2423737f, + 0.37881732f, -0.40608948f, 0.26193494f, -0.4283861f, -0.10062629f, -0.32670784f, -0.16040438f, -0.15297079f, + 0.1822241f, 0.37285012f, 0.12654608f, -0.46767431f, -0.28775263f, 0.16585541f, -0.36678362f, -0.4759978f, + -0.34751755f, -0.3163945f, -0.3858195f, -0.38030273f, -0.06156373f, -0.04352224f, -0.4041785f, -0.335764f, + -0.10303855f, -0.4009425f, -0.1236487f, -0.40111196f, 0.23985302f, -0.118291676f, -0.26773083f, 0.121197104f, + 0.3702919f, -0.34168184f, 0.33743858f, 0.24873763f, -0.23140603f, -0.25351608f, 0.48291886f, 0.13780516f, + 0.25632292f, -0.49343884f, 0.08369112f, -0.37192065f, -0.05451995f, -0.44571918f, -0.24150735f, 0.27395487f, + -0.20423341f, -0.024149835f, 0.40208143f, -0.18211937f, -0.19767642f, -0.19397742f, -0.1510992f, 0.48074025f, + 0.18377024f, -0.18288034f, 0.08111167f, 0.12729281f, 0.27861303f, 0.0076527f, 0.36356348f, -0.24359548f, + -0.33313757f, -0.374829f, -0.08705664f, 0.23576546f, -0.39819986f, -0.09880793f, -0.012998581f, -0.36475456f, + -0.32685202f, 0.29657948f, -0.4631365f, -0.06320876f, 0.31600899f, 0.060619473f, 0.39029974f, 0.401151f, + 0.15562236f, 0.43565983f, -0.058149397f, 0.36150748f, 0.10750586f, -0.063970566f, -0.47026545f, -0.3035437f, + -0.38143605f, -0.4734699f, 0.31273925f, -0.43410504f, 0.07299572f, 0.47506f, 0.021913886f, -0.036100805f, + -0.31637233f, 0.37718338f, -0.046213806f, 0.19239199f, 0.13676548f, 0.33592474f, -0.34048676f, -0.11097133f, + -0.41569126f, -0.01680845f, 0.31357706f, 0.0943895f, -0.24053341f, -0.018784225f, 0.40659577f, 0.08897692f, + 0.3793823f, -0.3271106f, 0.067666054f, -0.12331611f, -0.010209799f, -0.48908865f, 0.19195485f, -0.45211792f, + 0.48282713f, 0.4363466f, -0.40184838f, -0.025082052f, -0.31057972f, 0.14850605f, 0.39756012f, -0.25782883f, + 0.3181312f, 0.17685872f, -0.16694272f, -0.41516554f, -0.062004805f, -0.33060408f, -0.13665432f, -0.43781847f, + -0.298562f, 0.013283849f, 0.48130906f, -0.27970356f, 0.20347959f, -0.24402553f, -0.20528454f, -0.114435256f, + 0.12556863f, -0.4344011f, 0.2868948f, 0.19894183f, -0.12849897f, -0.18726158f, -0.4850099f, -0.4352169f, + -0.40527463f, 0.13625044f, -0.49707252f, -0.45698053f, 0.28196156f, 0.16826987f, -0.25944453f, 0.2801003f, + 0.21121234f, -0.04066527f, 0.45854944f, -0.17861038f, 0.18178529f, 0.17789757f, 0.34227383f, 0.26976448f, + 0.15789884f, 0.22840887f, 0.419321f, -0.14490443f, 0.39608955f, -0.4162954f, -0.47072983f, 0.41119635f}; + const std::vector fc2_experts_weights = { + 0.10833451f, 0.34020698f, -0.18258394f, -0.17842063f, -0.07365984f, -0.29177922f, -0.24102151f, 0.1077901f, + 0.2932343f, -0.35068116f, 0.1875877f, 0.07474385f, -0.20955177f, -0.27660736f, -0.14290786f, -0.09014153f, + -0.21085852f, -0.2378315f, 0.21457997f, 0.21074237f, -0.21087126f, 0.14320332f, -0.08389844f, 0.24034885f, + 0.31800103f, 0.12659892f, 0.20224877f, -0.2563875f, 0.11782206f, 0.29377612f, -0.27469966f, -0.18875091f, + 0.32136288f, 0.0788243f, -0.26413083f, 0.18453442f, 0.0776935f, -0.19561274f, 0.12608862f, 0.18579696f, + 0.045481127f, -0.17894714f, 0.27366453f, 0.13220324f, -0.3115706f, -0.016884197f, -0.3328494f, -0.062126897f, + 0.14841764f, 0.19741052f, 0.08211302f, -0.09362138f, -0.053040292f, -0.090344846f, 0.18264277f, 0.037823465f, + -0.16197139f, -0.20172869f, 0.064109616f, -0.062456656f, 0.30368346f, -0.12107184f, -0.12590908f, -0.10535928f, + 0.1978099f, 0.13119277f, 0.21948591f, -0.080250844f, -0.24614547f, 0.33202717f, 0.2645375f, -0.21193951f, + 0.17770219f, -0.04986229f, 0.33435768f, -0.0309231f, 0.16043694f, -0.0027341924f, -0.08339601f, -0.17402375f, + 0.2525901f, -0.0813988f, -0.2904943f, -0.14452116f, -0.27119386f, -0.2952116f, 0.0794895f, -0.11223866f, + 0.25427446f, 0.16967128f, 0.19531254f, -0.33598322f, -0.16714293f, -0.35097876f, -0.35189477f, 0.2900932f, + 0.26874313f, -0.1322388f, -0.330179f, 0.064027935f, 0.19688474f, -0.20129368f, 0.006225848f, 0.19252343f, + -0.35054854f, -0.31874785f, 0.32238203f, 0.29287276f, 0.03135616f, 0.015792634f, 0.20397249f, -0.3245995f, + 0.21416605f, 0.15667121f, -0.2058509f, 0.23639117f, -0.032677338f, 0.07826358f, -0.04589425f, -0.24935842f, + -0.20834164f, 0.069915086f, -0.26063374f, 0.13239416f, 0.33705652f, -0.26813045f, -0.17056243f, 0.29919288f, + 0.27704936f, -0.096224755f, 0.13250813f, 0.26709175f, -0.26995474f, 0.3261805f, -0.18062393f, -0.04732303f, + -0.02733084f, 0.050550338f, -0.2937818f, -0.19453493f, -0.34864828f, -0.20862648f, -0.19311349f, 0.17665526f, + -0.2894185f, -0.020016002f, 0.3409702f, -0.18320526f, 0.068286195f, 0.08490415f, 0.30223787f, -0.2386011f, + 0.09405743f, 0.123811804f, 0.31660154f, -0.11290163f, 0.07494662f, -0.24999082f, 0.2075398f, 0.07419645f, + 0.3327035f, -0.09647329f, 0.24138254f, -0.32546985f, 0.033594366f, 0.16555631f, 0.33516192f, -0.32619375f, + 0.20476541f, -0.07724f, 0.018923176f, -0.21126744f, 0.2744358f, -0.23979841f, -0.30413106f, -0.3485449f, + 0.2854276f, 0.14391156f, -0.24802732f, -0.21701548f, -0.122100174f, 0.054206114f, -0.21961808f, 0.13481297f, + -0.07907457f, 0.15763119f, -0.31156835f, 0.29488218f, 0.17039073f, 0.35125035f, -0.17721775f, -0.10516899f, + 0.072144486f, -0.038529005f, -0.058253434f, 0.13062657f, -0.3312356f, -0.15963489f, -0.20129326f, 0.014987925f, + 0.30869225f, 0.283981f, -0.057181682f, 0.15174268f, 0.22181617f, -0.19763571f, 0.28675067f, 0.0003976555f, + -0.34610963f, 0.2931936f, -0.26233214f, 0.19563977f, -0.16886877f, 0.022812065f, 0.080249704f, -0.2798801f, + 0.11531327f, 0.07107194f, -0.34746924f, -0.051920194f, -0.07264093f, 0.27581826f, 0.18536879f, 0.15684144f, + -0.26691115f, -0.22811417f, -0.1498502f, -0.176639f, -0.25876564f, -0.16051741f, -0.0048792143f, -0.08490091f, + 0.18136817f, 0.24729891f, 0.32358363f, -0.09566104f, 0.3074607f, -0.24191524f, -0.21220984f, -0.23039621f, + 0.21154472f, -0.19495378f, 0.002779711f, -0.34692943f, 0.055384878f, 0.25809082f, 0.16814983f, 0.19935164f, + 0.11652225f, 0.1115539f, -0.24407779f, 0.09392998f, 0.33556697f, 0.11422251f, 0.34336287f, -0.33113837f}; + const std::vector fc3_experts_weights = { + 0.45783097f, -0.2863351f, 0.011728346f, -0.43760604f, 0.15407985f, 0.07818556f, 0.0013856292f, -0.34319758f, + -0.16871625f, 0.12490183f, -0.34154075f, -0.31836903f, -0.46634215f, -0.43996066f, -0.1860516f, -0.2917009f, + -0.1772582f, -0.06599659f, -0.42419833f, 0.49980444f, -0.3283869f, -0.21543652f, -0.034647882f, -0.17114872f, + -0.4837973f, -0.362943f, -0.27533132f, 0.09443748f, -0.16642791f, -0.2993343f, -0.33881485f, -0.39464045f, + 0.31960344f, 0.007296145f, -0.45412838f, -0.024868786f, -0.16298121f, -0.44197202f, 0.07232875f, -0.32362783f, + 0.42969978f, -0.029854119f, -0.18451887f, -0.30145288f, 0.16885209f, -0.30068123f, -0.12948537f, 0.36494362f, + -0.049498677f, 0.12020564f, 0.42106473f, -0.30590254f, 0.31881082f, -0.078908324f, 0.20685762f, -0.22735089f, + -0.11194843f, 0.14011681f, 0.19477749f, -0.44788343f, 0.23084867f, 0.48367476f, -0.19044077f, -0.100233376f, + 0.4191656f, -0.4515314f, -0.3214385f, 0.016065598f, -0.4069137f, -0.17348295f, -0.43329984f, 0.33521235f, + -0.07843453f, -0.4865722f, -0.039011598f, -0.10605621f, 0.4192536f, 0.04063064f, 0.1984514f, 0.49294376f, + -0.056941032f, 0.18582922f, -0.16650558f, -0.17215621f, -0.20009357f, 0.46615022f, 0.47462142f, -0.0766145f, + -0.20405996f, -0.27452308f, -0.16176039f, -0.23940295f, 0.13248974f, 0.23036134f, 0.13154167f, 0.10377723f, + 0.0070211887f, 0.29162645f, 0.34465307f, -0.4058748f, -0.13989884f, -0.12305027f, -0.2541607f, 0.4767149f, + 0.4549045f, -0.108933926f, 0.2452516f, 0.054080307f, 0.33768386f, -0.45279485f, 0.1557768f, 0.17416143f, + -0.42602575f, -0.102350116f, 0.16022503f, 0.14813942f, 0.03982985f, -0.47012872f, -0.14555538f, 0.35645115f, + -0.1909796f, -0.20839584f, -0.28098184f, -0.23085594f, 0.022559166f, -0.23900753f, -0.19561106f, -0.24205637f, + 0.2573983f, -0.2947166f, 0.4568925f, 0.11514187f, 0.18671238f, -0.121082425f, 0.3909887f, -0.10985571f, + -0.19420451f, -0.3255307f, 0.4863913f, 0.007830441f, 0.4648854f, -0.24156213f, 0.22956276f, -0.09216207f, + -0.29428315f, 0.26062596f, 0.14955276f, -0.036366224f, -0.12957954f, 0.08501935f, -0.36796576f, 0.041123867f, + 0.06744653f, -0.0839923f, 0.17207885f, 0.006872058f, -0.21135789f, 0.3732242f, -0.2683524f, -0.45898575f, + -0.14543939f, 0.30806476f, 0.08574325f, 0.027492225f, -0.38164973f, -0.040038824f, -0.26947904f, -0.09740937f, + 0.26697665f, -0.43565083f, 0.1359719f, 0.12271714f, 0.0149876475f, -0.44011843f, 0.26128954f, -0.42487514f, + -0.24668545f, 0.06113738f, -0.29119557f, 0.194273f, -0.24981815f, 0.3489496f, -0.47321397f, -0.31794417f, + -0.23641628f, 0.44169098f, -0.006898284f, 0.43446392f, -0.39553195f, 0.057907403f, -0.19339961f, -0.08160931f, + 0.4979084f, -0.11149913f, 0.35366338f, -0.16032219f, -0.48278677f, 0.08397317f, 0.4008311f, 0.30288273f, + 0.2546957f, -0.10675722f, 0.069722414f, 0.456497f, -0.19691509f, 0.49017924f, 0.41796166f, -0.2337895f, + -0.3635872f, -0.45445484f, -0.29122698f, -0.4339773f, 0.15762383f, 0.09782606f, -0.27986187f, -0.23860168f, + 0.38454843f, -0.07870716f, 0.15390605f, -0.15793777f, 0.48130733f, 0.288768f, 0.45969498f, -0.4193731f, + -0.3218134f, -0.29914904f, -0.3426242f, 0.06931591f, -0.2633695f, -0.25429398f, 0.25366426f, -0.27700734f, + 0.49418402f, -0.21919805f, 0.041192472f, -0.19817531f, -0.49578953f, 0.48185098f, -0.41920406f, -0.08335745f, + 0.19111753f, -0.07547706f, 0.049694f, 0.13012594f, 0.2617172f, -0.22612399f, 0.32247066f, -0.33702326f, + 0.20062232f, -0.09143996f, -0.063310504f, 0.1885702f, 0.11926836f, 0.3378734f, -0.45973647f, 0.48845494f}; + const std::vector output = { + 0.026516449f, 0.04061616f, 0.04403834f, -0.13644142f, 0.038774252f, 0.024002096f, -0.061423667f, 0.034824893f, + -0.022858473f, 0.04693405f, -0.0120724365f, -0.028846134f, -0.0168579f, -0.07958221f, 0.048179876f, 0.053492386f, + -0.026292695f, -0.009724421f, -0.026503641f, 0.031220898f, 0.04189077f, 0.11775493f, -0.037770163f, -0.0790936f}; + + RunMoETest(input, + router_probs, + fc1_experts_weights, + fc2_experts_weights, + fc3_experts_weights, + {}, + {}, + output, + num_rows, + num_experts, + hidden_size, + inter_size, + "silu", + 1, /*normalize_routing_weights*/ + 2 /*top_k*/); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py b/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py index fd1d58cd2a3b8..ec64f2359f4be 100644 --- a/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py +++ b/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py @@ -24,25 +24,17 @@ def get_size(): return comm.Get_size() -def barrier(): - comm.Barrier() - - def print_out(*args): if get_rank() == 0: print(*args) -def broadcast(data): - comm = MPI.COMM_WORLD - comm.broadcast(data, root=0) - - local_rank = get_rank() ORT_DTYPE = TensorProto.FLOAT16 NP_TYPE = np.float16 if ORT_DTYPE == TensorProto.FLOAT16 else np.float32 -THRESHOLD = 1e-3 +THRESHOLD_TP = 3e-2 +THRESHOLD_EP = 1e-6 def create_moe_onnx_graph( @@ -52,12 +44,17 @@ def create_moe_onnx_graph( hidden_size, inter_size, fc1_experts_weights, - fc2_experts_weights, fc1_experts_bias, + fc2_experts_weights, fc2_experts_bias, - local_experts_start_index=-1, + fc3_experts_weights, + local_experts_start_index=0, + topk=2, + normalize_routing_weights=1, + activation_type="gelu", + tensor_shards=1, ): - use_sharded_moe = local_experts_start_index >= 0 + use_sharded_moe = num_experts > local_num_experts or tensor_shards > 1 nodes = [ ( helper.make_node( @@ -66,14 +63,16 @@ def create_moe_onnx_graph( "input", "router_probs", "fc1_experts_weights", - "fc2_experts_weights", "fc1_experts_bias", + "fc2_experts_weights", "fc2_experts_bias", + "fc3_experts_weights", ], ["output"], "MoE_0", - k=1, - activation_type="gelu", + k=topk, + normalize_routing_weights=normalize_routing_weights, + activation_type=activation_type, domain="com.microsoft", ) if not use_sharded_moe @@ -83,15 +82,18 @@ def create_moe_onnx_graph( "input", "router_probs", "fc1_experts_weights", - "fc2_experts_weights", "fc1_experts_bias", + "fc2_experts_weights", "fc2_experts_bias", + "fc3_experts_weights", ], ["output"], "MoE_0", - k=1, - activation_type="gelu", + k=topk, + normalize_routing_weights=normalize_routing_weights, + activation_type=activation_type, local_experts_start_index=local_experts_start_index, + tensor_shards=tensor_shards, domain="com.microsoft", ) ), @@ -99,6 +101,7 @@ def create_moe_onnx_graph( fc1_shape = [local_num_experts, hidden_size, inter_size] fc2_shape = [local_num_experts, inter_size, hidden_size] + fc3_shape = fc1_shape initializers = [ helper.make_tensor( @@ -115,6 +118,13 @@ def create_moe_onnx_graph( fc2_experts_weights.flatten(), raw=False, ), + helper.make_tensor( + "fc3_experts_weights", + ORT_DTYPE, + fc3_shape, + fc3_experts_weights.flatten(), + raw=False, + ), ] fc1_bias_shape = [local_num_experts, inter_size] @@ -166,18 +176,18 @@ def create_moe_onnx_graph( return model.SerializeToString() -def test_moe_with_expert_slicing( +def generate_weights_and_initial_model( + num_rows, + num_experts, hidden_size, inter_size, - num_experts, - num_rows, ): - local_experts_start_index = local_rank * num_experts // get_size() - - fc1_experts_weights_all = np.random.rand(num_experts, hidden_size, inter_size).astype(NP_TYPE) - fc2_experts_weights_all = np.random.rand(num_experts, inter_size, hidden_size).astype(NP_TYPE) - fc1_experts_bias_all = np.random.rand(num_experts, inter_size).astype(NP_TYPE) - fc2_experts_bias_all = np.random.rand(num_experts, hidden_size).astype(NP_TYPE) + s = 0.1 + fc1_experts_weights_all = np.random.normal(scale=s, size=(num_experts, hidden_size, inter_size)).astype(NP_TYPE) + fc2_experts_weights_all = np.random.normal(scale=s, size=(num_experts, inter_size, hidden_size)).astype(NP_TYPE) + fc3_experts_weights_all = np.random.normal(scale=s, size=(num_experts, hidden_size, inter_size)).astype(NP_TYPE) + fc1_experts_bias_all = np.random.normal(scale=s, size=(num_experts, inter_size)).astype(NP_TYPE) + fc2_experts_bias_all = np.random.normal(scale=s, size=(num_experts, hidden_size)).astype(NP_TYPE) onnx_model_full = create_moe_onnx_graph( num_rows, @@ -186,34 +196,31 @@ def test_moe_with_expert_slicing( hidden_size, inter_size, fc1_experts_weights_all, - fc2_experts_weights_all, fc1_experts_bias_all, + fc2_experts_weights_all, fc2_experts_bias_all, + fc3_experts_weights_all, ) - fc1_experts_weights = fc1_experts_weights_all[ - local_experts_start_index : local_experts_start_index + num_experts // get_size(), :, : - ] - fc2_experts_weights = fc2_experts_weights_all[ - local_experts_start_index : local_experts_start_index + num_experts // get_size(), :, : - ] - fc1_experts_bias = fc1_experts_bias_all[ - local_experts_start_index : local_experts_start_index + num_experts // get_size(), : - ] - - onnx_model_local = create_moe_onnx_graph( - num_rows, - num_experts, - num_experts // get_size(), - hidden_size, - inter_size, - fc1_experts_weights, - fc2_experts_weights, - fc1_experts_bias, + return ( + onnx_model_full, + fc1_experts_weights_all, + fc1_experts_bias_all, + fc2_experts_weights_all, fc2_experts_bias_all, - local_experts_start_index, + fc3_experts_weights_all, ) + +def run_ort_with_parity_check( + onnx_model_full, + onnx_model_local, + num_rows, + hidden_size, + num_experts, + inter_size, + threshold, +): sess_options = onnxruntime.SessionOptions() cuda_provider_options = {"device_id": local_rank} execution_providers = [("CUDAExecutionProvider", cuda_provider_options)] @@ -229,30 +236,161 @@ def test_moe_with_expert_slicing( output = ort_session.run(None, ort_inputs) sharded_output = ort_session_local.run(None, ort_inputs) - assert np.allclose(output[0], sharded_output[0], atol=THRESHOLD, rtol=THRESHOLD) + print_out("max diff:", np.max(np.abs(output[0] - sharded_output[0]))) + assert np.allclose(output[0], sharded_output[0], atol=threshold, rtol=threshold) print_out( - "hidden_size: ", + "hidden_size:", hidden_size, - " inter_size: ", + " inter_size:", inter_size, - " num_experts: ", + " num_experts:", num_experts, - " num_rows: ", + " num_rows:", num_rows, - " world_size: ", + " world_size:", get_size(), " Parity: OK", ) +def test_moe_with_tensor_parallelism( + hidden_size, + inter_size, + num_experts, + num_rows, + threshold=THRESHOLD_TP, +): + assert inter_size % get_size() == 0 + + ( + onnx_model_full, + fc1_experts_weights_all, + fc1_experts_bias_all, + fc2_experts_weights_all, + fc2_experts_bias_all, + fc3_experts_weights_all, + ) = generate_weights_and_initial_model( + num_rows, + num_experts, + hidden_size, + inter_size, + ) + + fc1_experts_weights = fc1_experts_weights_all[ + :, :, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size() + ] + fc2_experts_weights = fc2_experts_weights_all[ + :, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size(), : + ] + fc3_experts_weights = fc3_experts_weights_all[ + :, :, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size() + ] + fc1_experts_bias = fc1_experts_bias_all[ + :, local_rank * inter_size // get_size() : (local_rank + 1) * inter_size // get_size() + ] + + onnx_model_local = create_moe_onnx_graph( + num_rows, + num_experts, + num_experts, + hidden_size, + inter_size // get_size(), + fc1_experts_weights, + fc1_experts_bias, + fc2_experts_weights, + fc2_experts_bias_all, + fc3_experts_weights, + tensor_shards=get_size(), + ) + + run_ort_with_parity_check( + onnx_model_full, + onnx_model_local, + num_rows, + hidden_size, + num_experts, + inter_size, + threshold, + ) + + +def test_moe_with_expert_parallelism( + hidden_size, + inter_size, + num_experts, + num_rows, + threshold=THRESHOLD_EP, +): + local_experts_start_index = local_rank * num_experts // get_size() + + ( + onnx_model_full, + fc1_experts_weights_all, + fc1_experts_bias_all, + fc2_experts_weights_all, + fc2_experts_bias_all, + fc3_experts_weights_all, + ) = generate_weights_and_initial_model( + num_rows, + num_experts, + hidden_size, + inter_size, + ) + + fc1_experts_weights = fc1_experts_weights_all[ + local_experts_start_index : local_experts_start_index + num_experts // get_size(), :, : + ] + fc2_experts_weights = fc2_experts_weights_all[ + local_experts_start_index : local_experts_start_index + num_experts // get_size(), :, : + ] + fc3_experts_weights = fc3_experts_weights_all[ + local_experts_start_index : local_experts_start_index + num_experts // get_size(), :, : + ] + fc1_experts_bias = fc1_experts_bias_all[ + local_experts_start_index : local_experts_start_index + num_experts // get_size(), : + ] + + onnx_model_local = create_moe_onnx_graph( + num_rows, + num_experts, + num_experts // get_size(), + hidden_size, + inter_size, + fc1_experts_weights, + fc1_experts_bias, + fc2_experts_weights, + fc2_experts_bias_all, + fc3_experts_weights, + local_experts_start_index, + ) + + run_ort_with_parity_check( + onnx_model_full, + onnx_model_local, + num_rows, + hidden_size, + num_experts, + inter_size, + threshold, + ) + + class TestMoE(unittest.TestCase): - def test_moe_expert_slicing(self): - for hidden_size in [16, 128]: - for inter_size in [512, 1024]: - for num_experts in [8, 16, 32]: - for num_rows in [16, 128, 512]: - test_moe_with_expert_slicing( + def test_moe_parallelism(self): + for hidden_size in [128, 1024]: + for inter_size in [512, 2048]: + for num_experts in [64]: + for num_rows in [1024]: + print_out("EP") + test_moe_with_expert_parallelism( + hidden_size, + inter_size, + num_experts, + num_rows, + ) + print_out("TP") + test_moe_with_tensor_parallelism( hidden_size, inter_size, num_experts, diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py new file mode 100644 index 0000000000000..90b7da255081a --- /dev/null +++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py @@ -0,0 +1,365 @@ +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import unittest +from collections import OrderedDict + +import numpy +import torch +import torch.nn.functional as F +from onnx import TensorProto, helper +from torch import nn + +import onnxruntime + +torch.manual_seed(42) +numpy.random.seed(42) + +ORT_DTYPE = TensorProto.FLOAT +NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 +THRESHOLD = 3e-2 + + +def value_string_of(numpy_array): + arr = numpy_array.flatten() + lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)] + return "{\n " + "f,\n ".join(lines) + "f}" + + +def print_tensor(name, numpy_array): + print(f"const std::vector {name} = {value_string_of(numpy_array)};") + + +def create_moe_onnx_graph( + num_rows, + num_experts, + hidden_size, + inter_size, + fc1_experts_weights, + fc2_experts_weights, + fc3_experts_weights, + topk, +): + nodes = [ + helper.make_node( + "MoE", + [ + "input", + "router_probs", + "fc1_experts_weights", + "", + "fc2_experts_weights", + "", + "fc3_experts_weights", + ], + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, + activation_type="silu", + domain="com.microsoft", + ), + ] + + fc1_shape = [num_experts, hidden_size, inter_size] + fc2_shape = [num_experts, inter_size, hidden_size] + fc3_shape = [num_experts, hidden_size, inter_size] + + torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + + initializers = [ + helper.make_tensor( + "fc1_experts_weights", + ORT_DTYPE, + fc1_shape, + fc1_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_weights", + ORT_DTYPE, + fc2_shape, + fc2_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc3_experts_weights", + ORT_DTYPE, + fc3_shape, + fc3_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + ] + + graph_inputs = [ + helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + ORT_DTYPE, + [num_rows, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "silu": nn.SiLU, +} +ACT2FN = ClassInstantier(ACT2CLS) + + +class MixtralConfig: + def __init__( + self, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + rope_theta=1e6, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + +class MixtralBlockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states_1 = self.act_fn(self.w1(hidden_states)) + current_hidden_states_3 = self.w3(hidden_states) + current_hidden_states = current_hidden_states_1 * current_hidden_states_3 + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + + def __init__(self, config, batch_size, sequence_length): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + w1_list = [] + w2_list = [] + w3_list = [] + for i in range(self.num_experts): + w1_list.append(self.experts[i].w1.weight.transpose(0, 1)) + w2_list.append(self.experts[i].w2.weight.transpose(0, 1)) + w3_list.append(self.experts[i].w3.weight.transpose(0, 1)) + + self.moe_experts_weight1 = torch.stack(w1_list, dim=0) + self.moe_experts_weight2 = torch.stack(w2_list, dim=0) + self.moe_experts_weight3 = torch.stack(w3_list, dim=0) + + self.batch_size = batch_size + self.sequence_length = sequence_length + self.moe_onnx_graph = create_moe_onnx_graph( + self.batch_size * self.sequence_length, + self.num_experts, + self.hidden_dim, + self.ffn_dim, + self.moe_experts_weight1, + self.moe_experts_weight2, + self.moe_experts_weight3, + self.top_k, + ) + + self.ort_sess = self.create_ort_session() + + def create_ort_session(self): + from onnxruntime import InferenceSession, SessionOptions + + sess_options = SessionOptions() + + cuda_providers = ["CUDAExecutionProvider"] + if cuda_providers[0] not in onnxruntime.get_available_providers(): + return None + + sess_options.log_severity_level = 2 + ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) + + return ort_session + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states # , router_logits + + def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + ort_inputs = { + "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), + "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), + } + + ort_output = None + if self.ort_sess is not None: + ort_output = self.ort_sess.run(None, ort_inputs) + return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits + + # print_tensor("input", ort_inputs["input"]) + # print_tensor("router_probs", ort_inputs["router_probs"]) + # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) + # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) + # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) + # print_tensor("output", ort_output[0]) + + return None + + def parity_check(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) + torch_output = self.forward(hidden_state) + ort_output = self.ort_forward(hidden_state) + if ort_output is not None: + assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04) + print( + "batch_size:", + self.batch_size, + " sequence_length:", + self.sequence_length, + " max_diff:", + (torch_output - ort_output).abs().max(), + " parity: OK", + ) + + +class TestMixtralMoE(unittest.TestCase): + def test_mixtral_moe_parity(self): + for batch_size in [1, 16]: + for sequence_length in [128, 1024]: + # use a small sizes to speed up the test + config = MixtralConfig(hidden_size=256, intermediate_size=1024) + mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) + mixtral_moe.parity_check() + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py index 72ca5d9975c05..dbf6ee7dabb0e 100644 --- a/onnxruntime/test/python/transformers/test_parity_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_moe.py @@ -47,8 +47,8 @@ def create_moe_onnx_graph( hidden_size, inter_size, fc1_experts_weights, - fc2_experts_weights, fc1_experts_bias, + fc2_experts_weights, fc2_experts_bias, ): nodes = [ @@ -58,8 +58,8 @@ def create_moe_onnx_graph( "input", "router_probs", "fc1_experts_weights", - "fc2_experts_weights", "fc1_experts_bias", + "fc2_experts_weights", "fc2_experts_bias", ], ["output"], @@ -250,8 +250,8 @@ def __init__( in_features, hidden_features, self.moe_experts.weight1, - self.moe_experts.weight2, self.moe_experts.bias1, + self.moe_experts.weight2, self.moe_experts.bias2, ) @@ -296,8 +296,6 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): ).data_ptr(), ) - iobinding.synchronize_inputs() - iobinding.bind_output( name="output", device_type="cuda", @@ -308,11 +306,12 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): numpy.zeros(ort_inputs["input"].shape), "cuda", device_id ).data_ptr(), ) - iobinding.synchronize_outputs() s = time.time() for _ in range(repeat): + iobinding.synchronize_inputs() self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() e = time.time() print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") @@ -356,8 +355,8 @@ def onnx_forward(self, iobinding=False): # print_tensor("input", ort_inputs["input"]) # print_tensor("router_probs", ort_inputs["router_probs"]) # print_tensor("fc1_experts_weights", self.moe_experts.weight1.detach().numpy()) - # print_tensor("fc2_experts_weights", self.moe_experts.weight2.detach().numpy()) # print_tensor("fc1_experts_bias", self.moe_experts.bias1.detach().numpy()) + # print_tensor("fc2_experts_weights", self.moe_experts.weight2.detach().numpy()) # print_tensor("fc2_experts_bias", self.moe_experts.bias2.detach().numpy()) # print_tensor("output", ort_output[0])