
Commit

[MoE] Add TP and Mixtral MoE (#19945)
### Description

1. Support Tensor Parallelism in ShardedMoE (see the weight-sharding sketch after this list).
2. Make the necessary code changes to support Mixtral MoE.
3. Fix a bug related to using IOBinding in the test script.
4. Fix the input size limitation.
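For concreteness, here is a minimal sketch (not part of this PR) of how expert weights could be partitioned for the tensor parallelism mentioned in item 1: each rank keeps every expert but only a slice of `inter_size`, so fc1/fc3 are split along their output columns and fc2 along its input rows. All names and shapes below are illustrative assumptions.

```python
# Hypothetical illustration of tensor-parallel MoE weight sharding; shapes follow the
# MoE contrib-op convention (num_experts, hidden_size, inter_size), but the helper
# itself is not from this commit.
import numpy as np

num_experts, hidden_size, inter_size, tensor_shards = 8, 32, 64, 2

fc1 = np.random.rand(num_experts, hidden_size, inter_size)  # up/gate projection
fc2 = np.random.rand(num_experts, inter_size, hidden_size)  # down projection
fc3 = np.random.rand(num_experts, hidden_size, inter_size)  # extra Mixtral-style projection

def shard_expert_weights(rank: int):
    """Return the slice of each weight that the given rank would own."""
    lo = rank * inter_size // tensor_shards
    hi = (rank + 1) * inter_size // tensor_shards
    return fc1[:, :, lo:hi], fc2[:, lo:hi, :], fc3[:, :, lo:hi]
```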

### Motivation and Context
wangyems authored Mar 20, 2024
1 parent 3dfe4a5 commit 6ff31e0
Showing 18 changed files with 1,272 additions and 246 deletions.
16 changes: 11 additions & 5 deletions docs/ContribOperators.md
@@ -2931,8 +2931,8 @@ This version of the operator has been available since version 1 of the 'com.micr
### <a name="com.microsoft.MoE"></a><a name="com.microsoft.moe">**com.microsoft.MoE**</a>

Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) uses top 1,
GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf)
usually uses top 32 experts.
GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf)
usually uses top 32 experts, and Mixtral(https://huggingface.co/blog/mixtral) activates top 2 experts.


#### Version
@@ -2946,9 +2946,11 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Activation function to use. Choose from relu, gelu, silu and identity. Default is relu</dd>
<dt><tt>k</tt> : int</dt>
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
</dl>
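As a reading aid (not part of the diff), the two attributes above can be pictured with a small numpy sketch: softmax the router logits, keep the top `k` experts per row, and, when `normalize_routing_weights` is set, rescale the kept weights so they sum to 1. The actual CUDA kernels may differ in implementation details.

```python
import numpy as np

def route(router_logits: np.ndarray, k: int, normalize_routing_weights: int):
    # router_logits: (num_rows, num_experts)
    shifted = router_logits - router_logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    top_idx = np.argsort(-probs, axis=-1)[:, :k]            # selected expert ids per row
    top_w = np.take_along_axis(probs, top_idx, axis=-1)     # their routing weights
    if normalize_routing_weights:
        top_w = top_w / top_w.sum(axis=-1, keepdims=True)   # renormalize to sum to 1
    return top_idx, top_w
```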

#### Inputs (4 - 6)
#### Inputs (5 - 8)

<dl>
<dt><tt>input</tt> : T</dt>
@@ -2957,12 +2959,16 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T</dt>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size)</dd>
<dt><tt>fc2_experts_weights</tt> : T</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size)</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc2_experts_weights</tt> : T</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size)</dd>
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc3_experts_weights</tt> (optional) : T</dt>
<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size)</dd>
<dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
</dl>
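A minimal sketch of wiring the expanded input list (again, not part of the diff): the node below assumes the documented input order and the `com.microsoft` domain; tensor names and the chosen attribute values are placeholders.

```python
from onnx import helper

moe_node = helper.make_node(
    "MoE",
    inputs=[
        "input",                # (num_rows, hidden_size)
        "router_probs",         # (num_rows, num_experts)
        "fc1_experts_weights",  # (num_experts, hidden_size, inter_size)
        "fc1_experts_bias",     # optional (num_experts, inter_size)
        "fc2_experts_weights",  # (num_experts, inter_size, hidden_size)
        "fc2_experts_bias",     # optional (num_experts, hidden_size)
        "fc3_experts_weights",  # optional (num_experts, hidden_size, inter_size)
        "fc3_experts_bias",     # optional (num_experts, inter_size)
    ],
    outputs=["output"],
    domain="com.microsoft",
    k=2,
    activation_type="silu",
    normalize_routing_weights=1,
)
```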

#### Outputs
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -861,7 +861,7 @@ Do not modify directly.*
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
113 changes: 77 additions & 36 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <utility>

#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
@@ -35,6 +37,7 @@ using namespace ONNX_NAMESPACE;

template <typename T>
ShardedMoE<T>::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("tensor_shards", &tensor_shards_).IsOK());
ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("local_experts_start_index", &local_experts_start_index_).IsOK());
rank_to_experts_start_index_.resize(nccl_->Size());
// Initialize rank_to_experts_start_index_[0] to a value to convey that it is not initialized.
@@ -55,27 +58,36 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
// Create a {Rank, ExpertsStartIndex} map on Host.
AutoDestoryCudaEvent cuda_event;
cudaEvent_t& copy_event = cuda_event.Get();
ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));

const Tensor* input = context->Input<Tensor>(0);
const Tensor* router_probs = context->Input<Tensor>(1);
const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
const Tensor* fc2_experts_weights = context->Input<Tensor>(3);
const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4);
const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(3);
const Tensor* fc2_experts_weights = context->Input<Tensor>(4);
const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(5);
const Tensor* fc3_experts_weights_optional = context->Input<Tensor>(6);
const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(7);

MoEParameters moe_params(tensor_shards_);
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional,
fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional,
fc3_experts_bias_optional));

MoEParameters moe_params;
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights,
fc1_experts_bias_optional, fc2_experts_bias_optional));
ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
"num_experts should be divisible by world_size");

ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm);
if (moe_params.parallel_type == MoEParallelType::EP || moe_params.parallel_type == MoEParallelType::EPAndTP) {
ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));
}

ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm,
fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);

size_t ws_size =
moe_runner.getWorkspaceSize(static_cast<int>(moe_params.num_rows), static_cast<int>(moe_params.hidden_size),
static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
static_cast<int>(k_));
moe_runner.getWorkspaceSize(static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
static_cast<size_t>(moe_params.inter_size),
static_cast<size_t>(moe_params.num_experts), static_cast<size_t>(k_));

size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
@@ -93,19 +105,25 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
IAllocatorUniquePtr<void> expert_for_source_row =
IAllocator::MakeUniquePtr<void>(allocator, expert_for_source_row_size, false, stream);

// fc1_scales and fc2_scales are used in quantized MoE
const CudaT* fc1_scales_ptr = nullptr;
const CudaT* fc2_scales_ptr = nullptr;
const CudaT* fc_scales_ptr = nullptr;

moe_runner.run_moe_fc(reinterpret_cast<const CudaT*>(input->template Data<T>()),
reinterpret_cast<const CudaT*>(router_probs->template Data<T>()),
reinterpret_cast<const CudaT*>(fc1_experts_weights->template Data<T>()),
std::move(fc1_scales_ptr),
std::move(fc_scales_ptr),
fc1_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc1_experts_bias_optional->template Data<T>()),
activation_type_, reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()),
std::move(fc2_scales_ptr), static_cast<int>(moe_params.num_rows),
activation_type_,
fc3_experts_weights_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc3_experts_weights_optional->template Data<T>()),
std::move(fc_scales_ptr),
fc3_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc3_experts_bias_optional->template Data<T>()),
reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()),
std::move(fc_scales_ptr), static_cast<int>(moe_params.num_rows),
static_cast<int>(moe_params.hidden_size),
static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
static_cast<int>(moe_params.local_num_experts), static_cast<int>(local_experts_start_index_),
@@ -116,31 +134,54 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {

Tensor* output = context->Output(0, input->Shape());

size_t stride_count = moe_params.hidden_size;
size_t stride_bytes = stride_count * sizeof(CudaT);
int64_t total_past_rows = 0;
int64_t total_covered_rows = 0;
if (copy_event != nullptr) {
CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event));
if (moe_params.parallel_type == MoEParallelType::None) {
fc2_output_bc = std::move(fc2_output);
}
NCCL_RETURN_IF_ERROR(ncclGroupStart());
for (int rank = 0; rank < nccl_->Size(); ++rank) {
int64_t experts_start_index = rank_to_experts_start_index_[rank];
moe_runner.get_total_rows_info(experts_start_index,
moe_params.local_num_experts,
total_past_rows,
total_covered_rows);
const char* src = reinterpret_cast<const char*>(fc2_output.get()) + total_past_rows * stride_bytes;
char* dst = reinterpret_cast<char*>(fc2_output_bc.get()) + total_past_rows * stride_bytes;
NCCL_RETURN_IF_ERROR(ncclBroadcast(src,
dst,
total_covered_rows * stride_count,

if (moe_params.parallel_type == MoEParallelType::EPAndTP) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expert and Tensor Parallelism is not supported yet");
}

if (moe_params.parallel_type == MoEParallelType::TP) {
ORT_ENFORCE(moe_params.tensor_shards == nccl_->Size());
NCCL_RETURN_IF_ERROR(ncclGroupStart());
NCCL_RETURN_IF_ERROR(ncclAllReduce(reinterpret_cast<const char*>(fc2_output.get()),
reinterpret_cast<char*>(fc2_output_bc.get()),
fc2_output_size / sizeof(CudaT),
GetNcclDataType(input->DataType()),
rank,
ncclSum,
nccl_->Comm(),
Stream(context)));
NCCL_RETURN_IF_ERROR(ncclGroupEnd());
}

if (moe_params.parallel_type == MoEParallelType::EP) {
size_t stride_count = moe_params.hidden_size;
size_t stride_bytes = stride_count * sizeof(CudaT);
int64_t total_past_rows = 0;
int64_t total_covered_rows = 0;
if (copy_event != nullptr) {
CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event));
}
NCCL_RETURN_IF_ERROR(ncclGroupStart());
for (int rank = 0; rank < nccl_->Size(); ++rank) {
int64_t experts_start_index = rank_to_experts_start_index_[rank];
moe_runner.get_total_rows_info(experts_start_index,
moe_params.local_num_experts,
total_past_rows,
total_covered_rows);
const char* src = reinterpret_cast<const char*>(fc2_output.get()) + total_past_rows * stride_bytes;
char* dst = reinterpret_cast<char*>(fc2_output_bc.get()) + total_past_rows * stride_bytes;
NCCL_RETURN_IF_ERROR(ncclBroadcast(src,
dst,
total_covered_rows * stride_count,
GetNcclDataType(input->DataType()),
rank,
nccl_->Comm(),
Stream(context)));
}
NCCL_RETURN_IF_ERROR(ncclGroupEnd());
}
NCCL_RETURN_IF_ERROR(ncclGroupEnd());

ort_fastertransformer::finalize_moe_routing_kernelLauncher(
reinterpret_cast<CudaT*>(fc2_output_bc.get()), reinterpret_cast<CudaT*>(output->template MutableData<T>()),
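Why the TP branch above all-reduces `fc2_output` with `ncclSum` (a sketch under the assumption that `inter_size` is what gets sharded, ignoring routing for brevity): each shard produces a partial fc2 output, and summing the partials reproduces the unsharded result. Elementwise activations between fc1 and fc2 commute with the column split, so the identity still holds with them.

```python
import numpy as np

hidden_size, inter_size, shards = 16, 32, 4
x = np.random.rand(3, hidden_size)
w1 = np.random.rand(hidden_size, inter_size)
w2 = np.random.rand(inter_size, hidden_size)

full = (x @ w1) @ w2                     # unsharded fc1 -> fc2
partial_sum = sum(
    (x @ w1[:, r * inter_size // shards:(r + 1) * inter_size // shards])
    @ w2[r * inter_size // shards:(r + 1) * inter_size // shards, :]
    for r in range(shards)
)                                        # what ncclAllReduce(ncclSum) would combine
assert np.allclose(full, partial_sum)
```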
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
@@ -26,6 +26,7 @@ class ShardedMoE final : public NcclKernel, public MoEBase {
Status SynchronizeExpertsStartIndex(AllocatorPtr& alloc, OpKernelContext* ctx, cudaEvent_t& cuda_event) const;

int64_t local_experts_start_index_;
int64_t tensor_shards_;
std::vector<int64_t> rank_to_experts_start_index_;
};

33 changes: 31 additions & 2 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
@@ -83,10 +83,16 @@ namespace ort_fastertransformer {

struct EpilogueOpBiasSilu {};

struct EpilogueOpNoBiasSilu {};

struct EpilogueOpBiasReLU {};

struct EpilogueOpNoBiasReLU {};

struct EpilogueOpBiasFtGelu {};

struct EpilogueOpNoBiasFtGelu {};

struct EpilogueOpBias {};

struct EpilogueOpNoBias {};
@@ -101,13 +107,27 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBiasSilu> {
using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU> {
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBiasReLU> {
using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu> {
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
@@ -116,6 +136,14 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
cutlass::FloatRoundStyle::round_to_nearest, true>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBiasFtGelu> {
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling,
cutlass::FloatRoundStyle::round_to_nearest, true>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias> {
using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
@@ -126,8 +154,9 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBias> {
using Op =
cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, cutlass::epilogue::thread::ScaleType::Default>;
cutlass::epilogue::thread::LinearCombination<
ElementType, ElementsPerVectorAccess, ElementAccumulator,
ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
};

} // namespace ort_fastertransformer
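A rough numerical reading of the epilogue variants added above (an interpretation of the CUTLASS scale types, not their code): `NoBetaScaling` applies alpha and adds a bias source before the activation, while the new no-bias `OnlyAlphaScaling` variants skip the bias term entirely.

```python
import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

accum = np.random.rand(4, 8)   # GEMM accumulator tile
bias = np.random.rand(8)
alpha = 1.0

with_bias = silu(alpha * accum + bias)   # EpilogueOpBiasSilu-style
no_bias = silu(alpha * accum)            # EpilogueOpNoBiasSilu-style
```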
9 changes: 7 additions & 2 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
@@ -42,8 +42,13 @@ class MoeGemmRunner {
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
int num_experts, ActivationType activation_type, cudaStream_t stream);

void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert,
int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream);
void moe_gemm_act(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert,
int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts,
ActivationType activation_type, cudaStream_t stream);

void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C,
int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k,
int num_experts, cudaStream_t stream);

private:
template <typename EpilogueTag>
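A hedged reference sketch of the grouped-GEMM contract these declarations imply, as I read it: rows of `A` are pre-sorted by expert, and `total_rows_before_expert[e]` is the inclusive prefix sum of rows handled by experts `0..e`, so expert `e` multiplies its contiguous row block by its own weight matrix. This is an interpretation of the interface, not the kernel code.

```python
import numpy as np

def moe_gemm_reference(A, B, total_rows_before_expert):
    # A: (total_rows, gemm_k) activations, sorted by expert
    # B: (num_experts, gemm_k, gemm_n) per-expert weights
    # total_rows_before_expert: inclusive prefix sums of rows per expert
    C = np.empty((A.shape[0], B.shape[2]), dtype=A.dtype)
    start = 0
    for e, end in enumerate(total_rows_before_expert):
        C[start:end] = A[start:end] @ B[e]   # expert e's contiguous row block
        start = end
    return C
```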