From 6ff31e06d5757779b9c8d53e9d02a3b62b3e3438 Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Tue, 19 Mar 2024 21:28:15 -0700
Subject: [PATCH] [MoE] Add TP and Mixtral MoE (#19945)
### Description
1.Support Tensor Parallelism in ShardedMoE.
2.Make necessary code changes to support Mixtral MoE.
3.Fix a bug related to using IOBinding in test script.
4.Fix the input size limitation
### Motivation and Context
---
docs/ContribOperators.md | 16 +-
docs/OperatorKernels.md | 2 +-
.../cuda/collective/sharded_moe.cc | 113 ++++--
.../contrib_ops/cuda/collective/sharded_moe.h | 1 +
.../cuda/moe/ft_moe/epilogue_helpers.h | 33 +-
.../cuda/moe/ft_moe/moe_gemm_kernels.h | 9 +-
.../moe/ft_moe/moe_gemm_kernels_template.h | 48 ++-
.../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 304 +++++++++++----
.../contrib_ops/cuda/moe/ft_moe/moe_kernel.h | 22 +-
onnxruntime/contrib_ops/cuda/moe/moe.cc | 58 ++-
onnxruntime/contrib_ops/cuda/moe/moe_base.h | 50 ++-
.../core/graph/contrib_ops/collective_defs.cc | 32 +-
.../core/graph/contrib_ops/contrib_defs.cc | 11 +-
.../core/providers/cuda/cu_inc/common.cuh | 4 +-
onnxruntime/test/contrib_ops/moe_test.cc | 177 ++++++++-
.../sharded_moe/test_sharded_moe.py | 260 ++++++++++---
.../transformers/test_parity_mixtral_moe.py | 365 ++++++++++++++++++
.../python/transformers/test_parity_moe.py | 13 +-
18 files changed, 1272 insertions(+), 246 deletions(-)
create mode 100644 onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 5f0100fad95a2..32a4ca16b7824 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2931,8 +2931,8 @@ This version of the operator has been available since version 1 of the 'com.micr
### **com.microsoft.MoE**
Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1,
- GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf)
- usually uses top 32 experts.
+ GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf)
+ usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral).
#### Version
@@ -2946,9 +2946,11 @@ This version of the operator has been available since version 1 of the 'com.micr
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
k : int
Number of top experts to select from expert pool
+normalize_routing_weights : int
+Whether to normalize routing weights
-#### Inputs (4 - 6)
+#### Inputs (5 - 8)
- input : T
@@ -2957,12 +2959,16 @@ This version of the operator has been available since version 1 of the 'com.micr
- 2D input tensor with shape (num_rows, num_experts)
- fc1_experts_weights : T
- 3D input tensor with shape (num_experts, hidden_size, inter_size)
-- fc2_experts_weights : T
-- 3D input tensor with shape (num_experts, inter_size, hidden_size)
- fc1_experts_bias (optional) : T
- 2D optional input tensor with shape (num_experts, inter_size)
+- fc2_experts_weights : T
+- 3D input tensor with shape (num_experts, inter_size, hidden_size)
- fc2_experts_bias (optional) : T
- 2D optional input tensor with shape (num_experts, hidden_size)
+- fc3_experts_weights (optional) : T
+- 3D optional input tensor with shape (num_experts, hidden_size, inter_size)
+- fc3_experts_bias (optional) : T
+- 2D optional input tensor with shape (num_experts, inter_size)
#### Outputs
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index eddc3b7873d80..bca8e17b3dfd4 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -861,7 +861,7 @@ Do not modify directly.*
|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)|
-|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc2_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
index 40a667ffd5d83..2efc37cf98010 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+#include
+
#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
@@ -35,6 +37,7 @@ using namespace ONNX_NAMESPACE;
template
ShardedMoE::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) {
+ ORT_ENFORCE(op_kernel_info.GetAttr("tensor_shards", &tensor_shards_).IsOK());
ORT_ENFORCE(op_kernel_info.GetAttr("local_experts_start_index", &local_experts_start_index_).IsOK());
rank_to_experts_start_index_.resize(nccl_->Size());
// Initialize rank_to_experts_start_index_[0] to a value to convey that it is not initialized.
@@ -55,27 +58,36 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
// Create a {Rank, ExpertsStartIndex} map on Host.
AutoDestoryCudaEvent cuda_event;
cudaEvent_t& copy_event = cuda_event.Get();
- ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));
const Tensor* input = context->Input(0);
const Tensor* router_probs = context->Input(1);
const Tensor* fc1_experts_weights = context->Input(2);
- const Tensor* fc2_experts_weights = context->Input(3);
- const Tensor* fc1_experts_bias_optional = context->Input(4);
+ const Tensor* fc1_experts_bias_optional = context->Input(3);
+ const Tensor* fc2_experts_weights = context->Input(4);
const Tensor* fc2_experts_bias_optional = context->Input(5);
+ const Tensor* fc3_experts_weights_optional = context->Input(6);
+ const Tensor* fc3_experts_bias_optional = context->Input(7);
+
+ MoEParameters moe_params(tensor_shards_);
+ ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional,
+ fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional,
+ fc3_experts_bias_optional));
- MoEParameters moe_params;
- ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights,
- fc1_experts_bias_optional, fc2_experts_bias_optional));
ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
"num_experts should be divisible by world_size");
- ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm);
+ if (moe_params.parallel_type == MoEParallelType::EP || moe_params.parallel_type == MoEParallelType::EPAndTP) {
+ ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));
+ }
+
+ ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm,
+ fc3_experts_weights_optional != nullptr,
+ normalize_routing_weights_);
size_t ws_size =
- moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size),
- static_cast(moe_params.inter_size), static_cast(moe_params.num_experts),
- static_cast(k_));
+ moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size),
+ static_cast(moe_params.inter_size),
+ static_cast(moe_params.num_experts), static_cast(k_));
size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
@@ -93,19 +105,25 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
IAllocatorUniquePtr expert_for_source_row =
IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream);
- // fc1_scales and fc2_scales are used in quantized MoE
- const CudaT* fc1_scales_ptr = nullptr;
- const CudaT* fc2_scales_ptr = nullptr;
+ const CudaT* fc_scales_ptr = nullptr;
moe_runner.run_moe_fc(reinterpret_cast(input->template Data()),
reinterpret_cast(router_probs->template Data()),
reinterpret_cast(fc1_experts_weights->template Data()),
- std::move(fc1_scales_ptr),
+ std::move(fc_scales_ptr),
fc1_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast(fc1_experts_bias_optional->template Data()),
- activation_type_, reinterpret_cast(fc2_experts_weights->template Data()),
- std::move(fc2_scales_ptr), static_cast(moe_params.num_rows),
+ activation_type_,
+ fc3_experts_weights_optional == nullptr
+ ? nullptr
+ : reinterpret_cast(fc3_experts_weights_optional->template Data()),
+ std::move(fc_scales_ptr),
+ fc3_experts_bias_optional == nullptr
+ ? nullptr
+ : reinterpret_cast(fc3_experts_bias_optional->template Data()),
+ reinterpret_cast(fc2_experts_weights->template Data()),
+ std::move(fc_scales_ptr), static_cast(moe_params.num_rows),
static_cast(moe_params.hidden_size),
static_cast(moe_params.inter_size), static_cast(moe_params.num_experts),
static_cast(moe_params.local_num_experts), static_cast(local_experts_start_index_),
@@ -116,31 +134,54 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
Tensor* output = context->Output(0, input->Shape());
- size_t stride_count = moe_params.hidden_size;
- size_t stride_bytes = stride_count * sizeof(CudaT);
- int64_t total_past_rows = 0;
- int64_t total_covered_rows = 0;
- if (copy_event != nullptr) {
- CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event));
+ if (moe_params.parallel_type == MoEParallelType::None) {
+ fc2_output_bc = std::move(fc2_output);
}
- NCCL_RETURN_IF_ERROR(ncclGroupStart());
- for (int rank = 0; rank < nccl_->Size(); ++rank) {
- int64_t experts_start_index = rank_to_experts_start_index_[rank];
- moe_runner.get_total_rows_info(experts_start_index,
- moe_params.local_num_experts,
- total_past_rows,
- total_covered_rows);
- const char* src = reinterpret_cast(fc2_output.get()) + total_past_rows * stride_bytes;
- char* dst = reinterpret_cast(fc2_output_bc.get()) + total_past_rows * stride_bytes;
- NCCL_RETURN_IF_ERROR(ncclBroadcast(src,
- dst,
- total_covered_rows * stride_count,
+
+ if (moe_params.parallel_type == MoEParallelType::EPAndTP) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expert and Tensor Parallelism is not supported yet");
+ }
+
+ if (moe_params.parallel_type == MoEParallelType::TP) {
+ ORT_ENFORCE(moe_params.tensor_shards == nccl_->Size());
+ NCCL_RETURN_IF_ERROR(ncclGroupStart());
+ NCCL_RETURN_IF_ERROR(ncclAllReduce(reinterpret_cast(fc2_output.get()),
+ reinterpret_cast(fc2_output_bc.get()),
+ fc2_output_size / sizeof(CudaT),
GetNcclDataType(input->DataType()),
- rank,
+ ncclSum,
nccl_->Comm(),
Stream(context)));
+ NCCL_RETURN_IF_ERROR(ncclGroupEnd());
+ }
+
+ if (moe_params.parallel_type == MoEParallelType::EP) {
+ size_t stride_count = moe_params.hidden_size;
+ size_t stride_bytes = stride_count * sizeof(CudaT);
+ int64_t total_past_rows = 0;
+ int64_t total_covered_rows = 0;
+ if (copy_event != nullptr) {
+ CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event));
+ }
+ NCCL_RETURN_IF_ERROR(ncclGroupStart());
+ for (int rank = 0; rank < nccl_->Size(); ++rank) {
+ int64_t experts_start_index = rank_to_experts_start_index_[rank];
+ moe_runner.get_total_rows_info(experts_start_index,
+ moe_params.local_num_experts,
+ total_past_rows,
+ total_covered_rows);
+ const char* src = reinterpret_cast(fc2_output.get()) + total_past_rows * stride_bytes;
+ char* dst = reinterpret_cast(fc2_output_bc.get()) + total_past_rows * stride_bytes;
+ NCCL_RETURN_IF_ERROR(ncclBroadcast(src,
+ dst,
+ total_covered_rows * stride_count,
+ GetNcclDataType(input->DataType()),
+ rank,
+ nccl_->Comm(),
+ Stream(context)));
+ }
+ NCCL_RETURN_IF_ERROR(ncclGroupEnd());
}
- NCCL_RETURN_IF_ERROR(ncclGroupEnd());
ort_fastertransformer::finalize_moe_routing_kernelLauncher(
reinterpret_cast(fc2_output_bc.get()), reinterpret_cast(output->template MutableData()),
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
index 5ea4ae59c4020..827283a794dd6 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
@@ -26,6 +26,7 @@ class ShardedMoE final : public NcclKernel, public MoEBase {
Status SynchronizeExpertsStartIndex(AllocatorPtr& alloc, OpKernelContext* ctx, cudaEvent_t& cuda_event) const;
int64_t local_experts_start_index_;
+ int64_t tensor_shards_;
std::vector rank_to_experts_start_index_;
};
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
index 78d206bf1d9bc..b18a70e899d1c 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
@@ -83,10 +83,16 @@ namespace ort_fastertransformer {
struct EpilogueOpBiasSilu {};
+struct EpilogueOpNoBiasSilu {};
+
struct EpilogueOpBiasReLU {};
+struct EpilogueOpNoBiasReLU {};
+
struct EpilogueOpBiasFtGelu {};
+struct EpilogueOpNoBiasFtGelu {};
+
struct EpilogueOpBias {};
struct EpilogueOpNoBias {};
@@ -101,6 +107,13 @@ struct Epilogue;
};
+template
+struct Epilogue {
+ using Op = cutlass::epilogue::thread::LinearCombinationSilu;
+};
+
template
struct Epilogue {
using Op = cutlass::epilogue::thread::LinearCombinationRelu;
};
+template
+struct Epilogue {
+ using Op = cutlass::epilogue::thread::LinearCombinationRelu;
+};
+
template
struct Epilogue {
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
@@ -116,6 +136,14 @@ struct Epilogue;
};
+template
+struct Epilogue {
+ using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
+ cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator,
+ ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling,
+ cutlass::FloatRoundStyle::round_to_nearest, true>;
+};
+
template
struct Epilogue {
using Op = cutlass::epilogue::thread::LinearCombination
struct Epilogue {
using Op =
- cutlass::epilogue::thread::LinearCombination;
+ cutlass::epilogue::thread::LinearCombination<
+ ElementType, ElementsPerVectorAccess, ElementAccumulator,
+ ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
};
} // namespace ort_fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
index 60608f462fde5..e0f91ab806c85 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
@@ -42,8 +42,13 @@ class MoeGemmRunner {
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
int num_experts, ActivationType activation_type, cudaStream_t stream);
- void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert,
- int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream);
+ void moe_gemm_act(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert,
+ int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
+ ActivationType activation_type, cudaStream_t stream);
+
+ void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
+ int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
+ int num_experts, cudaStream_t stream);
private:
template
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
index a3dcf0da16b98..2a15fdfd1cc1a 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
@@ -311,8 +311,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig
template ::value>::type* = nullptr>
void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
- int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, int64_t gemm_k,
- int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/,
+ int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n,
+ int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/,
int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) {
switch (gemm_config.tile_config) {
case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
@@ -429,11 +429,47 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightTyp
}
template
-void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C,
- int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n,
- int64_t gemm_k, int num_experts, cudaStream_t stream) {
- run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
+void MoeGemmRunner::moe_gemm_act(const T* A, const WeightType* B, const T* weight_scales,
+ T* C, int64_t* total_rows_before_expert, int64_t total_rows,
+ int64_t gemm_n, int64_t gemm_k, int num_experts,
+ ActivationType activation_type, cudaStream_t stream) {
+ switch (activation_type) {
+ case ActivationType::Relu:
+ run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n,
+ gemm_k, num_experts, stream);
+ break;
+ case ActivationType::Gelu:
+ run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n,
+ gemm_k, num_experts, stream);
+ break;
+ case ActivationType::Silu:
+ run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n,
+ gemm_k, num_experts, stream);
+ break;
+ case ActivationType::Identity:
+ run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
+ num_experts, stream);
+ break;
+ case ActivationType::InvalidType:
+ ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM");
+ break;
+ default: {
+ ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM");
+ }
+ }
+}
+
+template
+void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases,
+ T* C, int64_t* total_rows_before_expert, int64_t total_rows,
+ int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream) {
+ if (biases != nullptr) {
+ run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
num_experts, stream);
+ } else {
+ run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k,
+ num_experts, stream);
+ }
}
} // namespace ort_fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
index a5b47bcddefbc..5e6e484567988 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -30,7 +30,6 @@
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
-#include "cutlass/numeric_types.h"
#ifdef __GNUC__
#pragma GCC diagnostic pop
@@ -49,15 +48,14 @@
#endif
namespace ort_fastertransformer {
-
static constexpr int WARP_SIZE = 32;
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template
-__launch_bounds__(TPB) __global__
- void moe_softmax(const T* input, const bool* finished, T* output, const int num_cols) {
+__launch_bounds__(TPB) __global__ void moe_softmax(const T* input, const bool* finished, T* output,
+ const int num_cols) {
using BlockReduce = cub::BlockReduce;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -108,14 +106,15 @@ __launch_bounds__(TPB) __global__
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
template
-__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, const int) {
+__launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int*, int*, int, int, bool) {
// Does not support pre-Kepler architectures
;
}
#else
template
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, const bool* finished, T* output,
- int* indices, int* source_rows, int num_experts, int k) {
+ int* indices, int* source_rows, int num_experts, int k,
+ bool normalize_routing_weights) {
using cub_kvp = cub::KeyValuePair;
using BlockReduce = cub::BlockReduce;
__shared__ typename BlockReduce::TempStorage tmpStorage;
@@ -128,6 +127,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
const bool should_process_row = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts;
+ float output_row_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
@@ -155,6 +155,13 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
output[idx] = result_kvp.value;
indices[idx] = should_process_row ? result_kvp.key : num_experts;
source_rows[idx] = k_idx * num_rows + block_row;
+
+ if (normalize_routing_weights && k_idx == k - 1) {
+#pragma unroll
+ for (int ki = 0; ki < k; ++ki) {
+ output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum);
+ }
+ }
}
__syncthreads();
}
@@ -178,7 +185,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
template
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topk_gating_softmax(const T* input, const bool* finished, T* output, int num_rows, int* indices,
- int* source_rows, int k) {
+ int* source_rows, int k, bool normalize_routing_weights) {
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
@@ -296,6 +303,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
+ float output_row_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
// First, each thread does the local argmax
float max_val = row_chunk[0];
@@ -336,8 +344,16 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
// single) thread per row of the input/output matrices.
const int idx = k * thread_row + k_idx;
output[idx] = T(max_val);
+ output_row_sum = output_row_sum + static_cast(max_val);
indices[idx] = should_process_row ? expert : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
+
+ if (normalize_routing_weights && k_idx == k - 1) {
+#pragma unroll
+ for (int ki = 0; ki < k; ++ki) {
+ output[idx - ki] = T(static_cast(output[idx - ki]) / output_row_sum);
+ }
+ }
}
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
@@ -370,7 +386,8 @@ struct TopkConstants {
template
void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row,
- int num_rows, int /*num_experts*/, int k, cudaStream_t stream) {
+ int num_rows, int /*num_experts*/, int k, bool normalize_routing_weights,
+ cudaStream_t stream) {
static constexpr unsigned long MAX_BYTES_PER_LDG = 16;
static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS);
@@ -382,61 +399,63 @@ void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
topk_gating_softmax
- <<>>(input, finished, output, num_rows, indices, source_row, k);
+ <<>>(input, finished, output, num_rows, indices, source_row, k,
+ normalize_routing_weights);
}
template
void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_output,
int* indices, int* source_row, int num_rows, int num_experts,
- int k, cudaStream_t stream) {
+ int k, bool normalize_routing_weights, cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
switch (num_experts) {
case 2: {
topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows,
- num_experts, k, stream);
+ num_experts, k, normalize_routing_weights, stream);
break;
}
case 4: {
topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows,
- num_experts, k, stream);
+ num_experts, k, normalize_routing_weights, stream);
break;
}
case 8: {
topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows,
- num_experts, k, stream);
+ num_experts, k, normalize_routing_weights, stream);
break;
}
case 16: {
topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows,
- num_experts, k, stream);
+ num_experts, k, normalize_routing_weights, stream);
break;
}
case 32: {
topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows,
- num_experts, k, stream);
+ num_experts, k, normalize_routing_weights, stream);
break;
}
case 64: {
topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows,
- num_experts, k, stream);
+ num_experts, k, normalize_routing_weights, stream);
break;
}
case 128: {
topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows,
- num_experts, k, stream);
+ num_experts, k, normalize_routing_weights, stream);
break;
}
case 256: {
topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows,
- num_experts, k, stream);
+ num_experts, k, normalize_routing_weights, stream);
break;
}
default: {
static constexpr int TPB = 256;
moe_softmax<<>>(input, finished, softmax_temp_output, num_experts);
moe_top_k
- <<>>(softmax_temp_output, finished, output, indices, source_row, num_experts, k);
+ <<>>(softmax_temp_output, finished, output, indices, source_row, num_experts, k,
+ normalize_routing_weights);
}
}
}
@@ -521,25 +540,31 @@ __global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, i
}
template
-CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version) {
- total_past_rows_ = 0;
- total_covered_rows_ = 0;
+CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version,
+ bool has_fc3,
+ bool normalize_routing_weights)
+ : has_fc3_(has_fc3),
+ total_past_rows_(0),
+ total_covered_rows_(0),
+ normalize_routing_weights_(normalize_routing_weights) {
moe_gemm_runner_.initialize(sm_version);
}
template
-size_t CutlassMoeFCRunner::getWorkspaceSize(int num_rows, const int hidden_size,
- const int inter_size, int num_experts,
- int k) {
- const int buf_size = static_cast(pad_to_multiple_of_16(k * num_rows * hidden_size));
- const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size));
- const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts));
- const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows));
- int num_softmax_outs = 0;
+size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_rows, const size_t hidden_size,
+ const size_t inter_size, size_t num_experts,
+ size_t k) {
+ total_covered_rows_ = k * num_rows;
+
+ const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size);
+ const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size);
+ const size_t padded_experts = pad_to_multiple_of_16(num_experts);
+ const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows);
+ size_t num_softmax_outs = 0;
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256) {
- num_softmax_outs = static_cast(pad_to_multiple_of_16(num_rows * num_experts));
+ num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts);
}
// softmax output, permuted_rows and permuted_experts have moved to outside of moe kernel, allocate them
@@ -548,13 +573,13 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(int num_rows,
total_ws_bytes += buf_size * sizeof(T); // permuted_data
total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_
total_ws_bytes += num_softmax_outs * sizeof(T);
- const int bytes_for_fc1_result = interbuf_size * sizeof(T);
- const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)));
- sorter_.update_num_experts(num_experts);
+ const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T);
+ const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows));
+ sorter_.update_num_experts(static_cast(num_experts));
- int bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
+ size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
- int remaining_bytes = static_cast(pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result));
+ size_t remaining_bytes = pad_to_multiple_of_16(sorter_ws_size_bytes - bytes_for_fc1_result);
bytes_for_intermediate_and_sorting += remaining_bytes;
}
@@ -563,13 +588,13 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(int num_rows,
}
template
-void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, int num_rows,
- const int hidden_size, const int inter_size,
- int num_experts, int k) {
- const int buf_size = static_cast(pad_to_multiple_of_16(k * num_rows * hidden_size));
- const int interbuf_size = static_cast(pad_to_multiple_of_16(k * num_rows * inter_size));
- const int padded_experts = static_cast(pad_to_multiple_of_16(num_experts));
- const int num_moe_inputs = static_cast(pad_to_multiple_of_16(k * num_rows));
+void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, size_t num_rows,
+ const size_t hidden_size, const size_t inter_size,
+ size_t num_experts, size_t k) {
+ const size_t buf_size = pad_to_multiple_of_16(k * num_rows * hidden_size);
+ const size_t interbuf_size = pad_to_multiple_of_16(k * num_rows * inter_size);
+ const size_t padded_experts = pad_to_multiple_of_16(num_experts);
+ const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows);
source_rows_ = (int*)ws_ptr;
permuted_rows_ = source_rows_ + num_moe_inputs;
@@ -578,28 +603,130 @@ void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr,
total_rows_before_expert_ = (int64_t*)(permuted_data_ + buf_size);
- fc1_result_ = (T*)(total_rows_before_expert_ + padded_experts);
+ if (has_fc3_) {
+ fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts);
+ fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size);
+ } else {
+ fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts);
+ }
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
if (!is_pow_2 || num_experts > 256) {
- softmax_out_ = (T*)(fc1_result_ + interbuf_size);
+ softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size);
} else {
softmax_out_ = nullptr;
}
}
+namespace {
+
+struct __align__(8) Half4 {
+ half2 x;
+ half2 y;
+};
+
+// TODO(wy): move to common header
+template
+struct T4;
+template <>
+struct T4 {
+ using Type = float4;
+};
+template <>
+struct T4 {
+ using Type = Half4;
+};
+
+template
+struct T2;
+template <>
+struct T2 {
+ using Type = float2;
+};
+template <>
+struct T2 {
+ using Type = half2;
+};
+
+inline __device__ float2 operator*(const float2 a, const float2 b) {
+ return make_float2(a.x * b.x, a.y * b.y);
+}
+
+inline __device__ float4 operator*(const float4 a, const float4 b) {
+ return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
+}
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
+inline __device__ half operator*(const half a, const half b) {
+ return __float2half(__half2float(a) * __half2float(b));
+}
+
+inline __device__ half2 operator*(const half2 a, const half2 b) {
+ return make_half2(a.x * b.x, a.y * b.y);
+}
+#endif
+
+inline __device__ Half4 operator*(const Half4 a, const Half4 b) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
+ Half4 result;
+ result.x = a.x * b.x;
+ result.y = a.y * b.y;
+ return result;
+#else
+ return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)};
+#endif
+}
+
+} // anonymous namespace
+
+template
+__global__ void elementWiseMulKernel(T* output, T const* input, size_t inter_size) {
+ int const tid = threadIdx.x;
+ int const token = blockIdx.x;
+
+ output = output + token * inter_size;
+ input = input + token * inter_size;
+ for (int i = tid; i < inter_size; i += blockDim.x) {
+ T fc1_value = input[i];
+ output[i] = fc1_value * output[i];
+ }
+}
+
+template
+void elementWiseMul(T* output, T const* input, int inter_size, int num_tokens, cudaStream_t stream) {
+ int const blocks = num_tokens;
+
+ if (inter_size & 3 == 0) {
+ using vec_type = typename T4::Type;
+ int const threads = std::min(inter_size / 4, 1024);
+ elementWiseMulKernel<<>>(reinterpret_cast(output),
+ reinterpret_cast(input),
+ inter_size / 4);
+ } else if (inter_size & 1 == 0) {
+ using vec_type = typename T2::Type;
+ int const threads = std::min(inter_size / 2, 1024);
+ elementWiseMulKernel<<>>(reinterpret_cast(output),
+ reinterpret_cast(input),
+ inter_size / 2);
+ } else {
+ int const threads = std::min(inter_size, 1024);
+ elementWiseMulKernel<<>>(output, input, inter_size);
+ }
+}
+
template
void CutlassMoeFCRunner::run_moe_fc(
const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales,
- const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights,
- const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts,
- int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result,
- const bool* finished, int active_rows, T* expert_scales, int* expanded_source_row_to_expanded_dest_row,
- int* expert_for_source_row, cudaStream_t stream) {
+ const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc3_expert_weights,
+ const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales,
+ int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts,
+ int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows,
+ T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row,
+ cudaStream_t stream) {
static constexpr bool scales_required =
std::is_same::value || std::is_same::value;
- if constexpr (scales_required) {
+ if (scales_required) {
if (fc1_scales == nullptr) {
ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for first matmul is a null pointer");
} else if (fc2_scales == nullptr) {
@@ -613,9 +740,10 @@ void CutlassMoeFCRunner::run_moe_fc(
}
}
- configure_ws_ptrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts, k);
+ configure_ws_ptrs(workspace_ptr, static_cast(num_rows), static_cast(hidden_size),
+ static_cast(inter_size), static_cast(num_experts), static_cast(k));
topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row,
- source_rows_, num_rows, num_experts, k, stream);
+ source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream);
const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)));
sorter_.run((void*)fc1_result_, sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_,
@@ -634,15 +762,48 @@ void CutlassMoeFCRunner::run_moe_fc(
}
// expanded_active_expert_rows is not used
- moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size,
- fc1_expert_weights, fc1_scales, fc1_expert_biases,
- fc1_result_ + total_past_rows_ * inter_size,
- total_rows_before_expert_ + local_experts_start_index,
- expanded_active_expert_rows, inter_size, hidden_size,
- local_num_experts, fc1_activation_type, stream);
+ if (fc1_expert_biases != nullptr) {
+ moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size,
+ fc1_expert_weights, fc1_scales, fc1_expert_biases,
+ fc1_result_ + total_past_rows_ * inter_size,
+ total_rows_before_expert_ + local_experts_start_index,
+ expanded_active_expert_rows, inter_size, hidden_size,
+ local_num_experts, fc1_activation_type, stream);
+ } else {
+ moe_gemm_runner_.moe_gemm_act(permuted_data_ + total_past_rows_ * hidden_size,
+ fc1_expert_weights, fc1_scales,
+ fc1_result_ + total_past_rows_ * inter_size,
+ total_rows_before_expert_ + local_experts_start_index,
+ expanded_active_expert_rows, inter_size, hidden_size,
+ local_num_experts, fc1_activation_type, stream);
+ }
+
+ if (has_fc3_) {
+ if (scales_required) {
+ if (fc3_scales == nullptr) {
+ ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for third matmul is a null pointer");
+ }
+ } else {
+ if (fc3_scales != nullptr) {
+ ORT_THROW("[FT Error][Run MoE FC] Scales are ignored for fp32/fp16/bf16 but received scale for FC3");
+ }
+ }
+ if (fc3_expert_weights == nullptr) {
+ ORT_THROW("[FT Error][Run MoE FC] FC3 weights are null");
+ }
+ moe_gemm_runner_.moe_gemm(permuted_data_ + total_past_rows_ * hidden_size,
+ fc3_expert_weights, fc3_scales, fc3_expert_biases,
+ fc3_result_ + total_past_rows_ * inter_size,
+ total_rows_before_expert_ + local_experts_start_index,
+ expanded_active_expert_rows, inter_size, hidden_size,
+ local_num_experts, stream);
+
+ elementWiseMul(fc1_result_ + total_past_rows_ * inter_size, fc3_result_ + total_past_rows_ * inter_size,
+ static_cast(inter_size), static_cast(total_covered_rows_), stream);
+ }
moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size,
- fc2_expert_weights, fc2_scales,
+ fc2_expert_weights, fc2_scales, nullptr,
fc2_result + total_past_rows_ * hidden_size,
total_rows_before_expert_ + local_experts_start_index,
expanded_active_expert_rows, hidden_size, inter_size, local_num_experts, stream);
@@ -651,14 +812,16 @@ void CutlassMoeFCRunner::run_moe_fc(
template
void CutlassMoeFCRunner::run_moe_fc(
const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales,
- const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights,
- const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts,
- int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales,
+ const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc3_expert_weights,
+ const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales,
+ int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts,
+ int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales,
int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) {
run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type,
- fc2_expert_weights, fc2_scales, num_rows, hidden_size, inter_size, num_experts, local_num_experts,
- local_experts_start_index, k, workspace_ptr, fc2_result, nullptr, num_rows, expert_scales,
- expanded_source_row_to_expanded_dest_row, expert_for_source_row, stream);
+ fc3_expert_weights, fc3_scales, fc3_expert_biases, fc2_expert_weights, fc2_scales, num_rows, hidden_size,
+ inter_size, num_experts, local_num_experts, local_experts_start_index, k, workspace_ptr, fc2_result,
+ nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row,
+ stream);
}
template
@@ -811,9 +974,10 @@ __global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows, T*
const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols;
const int expert_idx = expert_for_source_row[k_offset];
- const T* bias_ptr = bias + expert_idx * cols;
+ const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr;
- thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + bias_ptr[tid]);
+ thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] +
+ (bias_ptr ? bias_ptr[tid] : T(0)));
}
reduced_row_ptr[tid] = thread_output;
}
@@ -866,9 +1030,9 @@ void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* red
// ========================= TopK Softmax specializations ===========================
template void topk_gating_softmax_kernelLauncher(const float*, const bool*, float*, float*, int*, int*, int,
- int, int, cudaStream_t);
+ int, int, bool, cudaStream_t);
template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int,
- int, int, cudaStream_t);
+ int, int, bool, cudaStream_t);
// ==================== Variable batched GEMM specializations ==================================
template class CutlassMoeFCRunner;
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
index 5cc2a3f79f003..5eef6f95f4820 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -24,6 +24,8 @@
#include "core/common/common.h"
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
+#include "cutlass/numeric_types.h"
+
using namespace onnxruntime;
namespace ort_fastertransformer {
@@ -107,12 +109,13 @@ template
class CutlassMoeFCRunner {
public:
- CutlassMoeFCRunner(int sm_version);
+ CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights);
- size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k);
+ size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k);
void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights,
const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type,
+ const WeightType* fc3_expert_weights, const T* fc3_scales, const T* fc3_expert_biases,
const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size,
int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k,
char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row,
@@ -120,6 +123,7 @@ class CutlassMoeFCRunner {
void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights,
const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type,
+ const WeightType* fc3_expert_weights, const T* fc3_scales, const T* fc3_expert_biases,
const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size,
int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k,
char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, T* expert_scales,
@@ -135,7 +139,8 @@ class CutlassMoeFCRunner {
int64_t& total_covered_rows);
private:
- void configure_ws_ptrs(char* ws_ptr, int num_rows, int hidden_size, int inter_size, int num_experts, int k);
+ void configure_ws_ptrs(char* ws_ptr, size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts,
+ size_t k);
private:
CubKeyValueSorter sorter_;
@@ -152,12 +157,17 @@ class CutlassMoeFCRunner {
int64_t* total_rows_before_expert_;
T* fc1_result_;
+ T* fc3_result_;
+
+ bool has_fc3_;
+ bool normalize_routing_weights_;
// Cuda events
contrib::cuda::AutoDestoryCudaEvent cuda_event_;
int64_t total_past_rows_;
int64_t total_covered_rows_;
+
// TODO: use pinned memory
std::vector total_rows_before_expert_host_;
};
@@ -165,11 +175,11 @@ class CutlassMoeFCRunner {
template
class CutlassMoeFCRunner::value>> {
public:
- CutlassMoeFCRunner(int sm_version);
+ CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights);
- size_t getWorkspaceSize(int num_rows, int hidden_size, int inter_size, int num_experts, int k) {
+ size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) {
return 0;
}
};
-} // namespace ort_fastertransformer
\ No newline at end of file
+} // namespace ort_fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc
index 3f26a274109ad..b13aab959fc48 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -39,13 +39,16 @@ Status MoE::ComputeInternal(OpKernelContext* context) const {
const Tensor* input = context->Input(0);
const Tensor* router_probs = context->Input(1);
const Tensor* fc1_experts_weights = context->Input(2);
- const Tensor* fc2_experts_weights = context->Input(3);
- const Tensor* fc1_experts_bias_optional = context->Input(4);
+ const Tensor* fc1_experts_bias_optional = context->Input(3);
+ const Tensor* fc2_experts_weights = context->Input(4);
const Tensor* fc2_experts_bias_optional = context->Input(5);
+ const Tensor* fc3_experts_weights_optional = context->Input(6);
+ const Tensor* fc3_experts_bias_optional = context->Input(7);
MoEParameters moe_params;
- ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights,
- fc1_experts_bias_optional, fc2_experts_bias_optional));
+ ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional,
+ fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional,
+ fc3_experts_bias_optional));
typedef typename ToCudaType::MappedType CudaT;
auto stream = context->GetComputeStream();
@@ -53,12 +56,14 @@ Status MoE::ComputeInternal(OpKernelContext* context) const {
auto& device_prop = GetDeviceProp();
const int sm = device_prop.major * 10 + device_prop.minor;
- ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm);
+ ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm,
+ fc3_experts_weights_optional != nullptr,
+ normalize_routing_weights_);
size_t ws_size =
- moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size),
- static_cast(moe_params.inter_size), static_cast(moe_params.num_experts),
- static_cast(k_));
+ moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size),
+ static_cast(moe_params.inter_size),
+ static_cast(moe_params.num_experts), static_cast(k_));
size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int);
@@ -77,26 +82,37 @@ Status MoE::ComputeInternal(OpKernelContext* context) const {
IAllocatorUniquePtr expert_for_source_row =
IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream);
- // fc1_scales and fc2_scales are used in quantized MoE
- const CudaT* fc1_scales_ptr = nullptr;
- const CudaT* fc2_scales_ptr = nullptr;
-
+ const CudaT* fc_scales_ptr = nullptr;
moe_runner.run_moe_fc(reinterpret_cast(input->template Data()),
reinterpret_cast(router_probs->template Data()),
- reinterpret_cast(fc1_experts_weights->template Data()),
- std::move(fc1_scales_ptr),
+ reinterpret_cast(fc1_experts_weights->DataRaw()),
+ fc_scales_ptr,
fc1_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast(fc1_experts_bias_optional->template Data()),
- activation_type_, reinterpret_cast(fc2_experts_weights->template Data()),
- std::move(fc2_scales_ptr), static_cast(moe_params.num_rows),
- static_cast(moe_params.hidden_size), static_cast(moe_params.inter_size),
- static_cast(moe_params.num_experts), static_cast(moe_params.local_num_experts),
- 0 /*local_experts_start_index_ used in sharded MoE*/, static_cast(k_),
- reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()),
+ activation_type_,
+ fc3_experts_weights_optional == nullptr
+ ? nullptr
+ : reinterpret_cast(fc3_experts_weights_optional->DataRaw()),
+ fc_scales_ptr,
+ fc3_experts_bias_optional == nullptr
+ ? nullptr
+ : reinterpret_cast(fc3_experts_bias_optional->template Data()),
+ reinterpret_cast(fc2_experts_weights->DataRaw()),
+ fc_scales_ptr,
+ static_cast(moe_params.num_rows),
+ static_cast(moe_params.hidden_size),
+ static_cast