add QMoE (#20108)
### Description
1. Introduce the latest CUTLASS extensions from TensorRT-LLM, which opens an upgrade path to CUTLASS 3.4 on the MoE side.
2. Fix a Windows build issue.
3. Add the int4 QMoE operator and unit tests (see the weight-packing sketch below).
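
As a hedged aside (not code from this PR), the sketch below shows what "int4" implies for the quantized weight tensors: two 4-bit values are packed into each uint8 element, which is why the packed weight dimensions in the QMoE spec later on this page are `inter_size / 2` and `hidden_size / 2`. The nibble ordering used here is an assumption for illustration; the kernel's actual packing convention may differ.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper, not code from this PR: pack int4 values (stored here as
// int8_t in the range [-8, 7]) two per uint8 byte. Placing the even-indexed
// value in the low nibble is an illustrative assumption.
std::vector<uint8_t> PackInt4(const std::vector<int8_t>& values) {
  std::vector<uint8_t> packed((values.size() + 1) / 2, 0);
  for (size_t i = 0; i < values.size(); ++i) {
    const uint8_t nibble = static_cast<uint8_t>(values[i]) & 0x0F;
    packed[i / 2] |= (i % 2 == 0) ? nibble : static_cast<uint8_t>(nibble << 4);
  }
  return packed;
}

// A row of inter_size int4 weights therefore occupies inter_size / 2 bytes,
// matching the (num_experts, hidden_size, inter_size / 2) weight shape in the
// QMoE spec documented further down.
```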



wangyems authored Mar 29, 2024
1 parent 2092beb commit 1791971
Showing 60 changed files with 10,748 additions and 1,497 deletions.
2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -60,6 +60,8 @@ set(contrib_ops_excluded_files
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"
"quantization/moe_quantization.h"
"quantization/moe_quantization.cc"
"quantization/quantize_dequantize_linear.cc"
"quantization/qordered_ops/qordered_attention_impl.cu"
"quantization/qordered_ops/qordered_attention_impl.h"
64 changes: 64 additions & 0 deletions docs/ContribOperators.md
@@ -78,6 +78,7 @@ Do not modify directly.*
* <a href="#com.microsoft.QLinearSigmoid">com.microsoft.QLinearSigmoid</a>
* <a href="#com.microsoft.QLinearSoftmax">com.microsoft.QLinearSoftmax</a>
* <a href="#com.microsoft.QLinearWhere">com.microsoft.QLinearWhere</a>
* <a href="#com.microsoft.QMoE">com.microsoft.QMoE</a>
* <a href="#com.microsoft.QOrderedAttention">com.microsoft.QOrderedAttention</a>
* <a href="#com.microsoft.QOrderedGelu">com.microsoft.QOrderedGelu</a>
* <a href="#com.microsoft.QOrderedLayerNormalization">com.microsoft.QOrderedLayerNormalization</a>
@@ -4261,6 +4262,69 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.QMoE"></a><a name="com.microsoft.qmoe">**com.microsoft.QMoE**</a>

Int4-quantized mixture-of-experts (MoE)

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>activation_type</tt> : string</dt>
<dd>Activation function to use. Choose from relu, gelu, silu, and identity. Default is relu.</dd>
<dt><tt>k</tt> : int</dt>
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
</dl>
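
The `k` and `normalize_routing_weights` attributes govern expert selection. As a rough CPU-side illustration only (the operator's CUDA kernel may differ, e.g. by applying a softmax to `router_probs` first), the sketch below shows what the two attributes mean; `TopKExperts` is a hypothetical helper, not an ONNX Runtime API.

```cpp
#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

// Illustrative sketch: pick the top-k experts for one token from its router
// scores and, if requested, renormalize the selected weights to sum to 1.
std::vector<std::pair<int, float>> TopKExperts(const std::vector<float>& router_probs,
                                               int k, bool normalize_routing_weights) {
  std::vector<int> idx(router_probs.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [&](int a, int b) { return router_probs[a] > router_probs[b]; });

  std::vector<std::pair<int, float>> selected;
  float sum = 0.0f;
  for (int i = 0; i < k; ++i) {
    selected.emplace_back(idx[i], router_probs[idx[i]]);
    sum += router_probs[idx[i]];
  }
  if (normalize_routing_weights && sum > 0.0f) {
    for (auto& e : selected) e.second /= sum;
  }
  return selected;
}
```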

#### Inputs (7 - 11)

<dl>
<dt><tt>input</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>router_probs</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size / 2)</dd>
<dt><tt>fc1_scales</tt> : T</dt>
<dd>2D input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc2_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size / 2)</dd>
<dt><tt>fc2_scales</tt> : T</dt>
<dd>2D input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc3_experts_weights</tt> (optional) : T1</dt>
<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)</dd>
<dt><tt>fc3_scales</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>2D output tensor with shape (num_rows, hidden_size) or 3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16)</dt>
<dd>Constrain input and output types to float16 tensors.</dd>
<dt><tt>T1</tt> : tensor(uint8)</dt>
<dd>Constrain weights type to uint8 tensors.</dd>
</dl>
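
To make the shape contract above concrete, here is a small hypothetical C++ sketch (not part of this PR) that computes the expected element counts of the main QMoE inputs from the spec; the 4-bit weights are packed two per uint8, so the packed weight dimension is half the logical one. The helper name and the example dimensions are illustrative assumptions.

```cpp
#include <cstddef>
#include <cstdio>

// Expected element counts for the main QMoE inputs, mirroring the spec above.
struct QMoEShapes {
  size_t input;                // (num_rows, hidden_size)
  size_t router_probs;         // (num_rows, num_experts)
  size_t fc1_experts_weights;  // (num_experts, hidden_size, inter_size / 2), uint8
  size_t fc1_scales;           // (num_experts, inter_size)
  size_t fc2_experts_weights;  // (num_experts, inter_size, hidden_size / 2), uint8
  size_t fc2_scales;           // (num_experts, hidden_size)
};

QMoEShapes ExpectedQMoESizes(size_t num_rows, size_t hidden_size,
                             size_t inter_size, size_t num_experts) {
  return QMoEShapes{
      num_rows * hidden_size,
      num_rows * num_experts,
      num_experts * hidden_size * (inter_size / 2),
      num_experts * inter_size,
      num_experts * inter_size * (hidden_size / 2),
      num_experts * hidden_size};
}

int main() {
  // Arbitrary example sizes, chosen only to show the arithmetic.
  QMoEShapes s = ExpectedQMoESizes(/*num_rows=*/2, /*hidden_size=*/3072,
                                   /*inter_size=*/10752, /*num_experts=*/8);
  std::printf("fc1 packed weight bytes: %zu\n", s.fc1_experts_weights);
  return 0;
}
```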


### <a name="com.microsoft.QOrderedAttention"></a><a name="com.microsoft.qorderedattention">**com.microsoft.QOrderedAttention**</a>

Quantized version of simplified Multi-Head Self-Attention (using int8 with a specific matrix layout).
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -868,6 +868,7 @@ Do not modify directly.*
|PackedAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* relative_position_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PackedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* relative_position_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float16)<br/> **T1** = tensor(uint8)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* relative_position_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale:**F**<br> *in* B:**F**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
126 changes: 48 additions & 78 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -18,23 +18,15 @@ namespace cuda {

#if defined(ORT_USE_NCCL)

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ShardedMoE, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.MayInplace(0, 0) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ShardedMoE, kMSDomain, 1, T, kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ShardedMoE<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)

using namespace ONNX_NAMESPACE;

template <typename T>
ShardedMoE<T>::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("tensor_shards", &tensor_shards_).IsOK());
@@ -69,25 +61,23 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(7);

MoEParameters moe_params(tensor_shards_);
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional,
fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional,
fc3_experts_bias_optional));
MoEQuantType quant_type = MoEQuantType::None;
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights,
fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional,
fc3_experts_weights_optional, fc3_experts_bias_optional));

ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
"num_experts should be divisible by world_size");
ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size");

if (moe_params.parallel_type == MoEParallelType::EP || moe_params.parallel_type == MoEParallelType::EPAndTP) {
ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));
}

ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm,
fc3_experts_weights_optional != nullptr,
ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);

size_t ws_size =
moe_runner.getWorkspaceSize(static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
static_cast<size_t>(moe_params.inter_size),
static_cast<size_t>(moe_params.num_experts), static_cast<size_t>(k_));
size_t ws_size = moe_runner.getWorkspaceSize(
static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
static_cast<size_t>(moe_params.inter_size), static_cast<size_t>(moe_params.num_experts), static_cast<size_t>(k_));

size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
@@ -107,30 +97,29 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {

const CudaT* fc_scales_ptr = nullptr;

moe_runner.run_moe_fc(reinterpret_cast<const CudaT*>(input->template Data<T>()),
reinterpret_cast<const CudaT*>(router_probs->template Data<T>()),
reinterpret_cast<const CudaT*>(fc1_experts_weights->template Data<T>()),
std::move(fc_scales_ptr),
fc1_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc1_experts_bias_optional->template Data<T>()),
activation_type_,
fc3_experts_weights_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc3_experts_weights_optional->template Data<T>()),
std::move(fc_scales_ptr),
fc3_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc3_experts_bias_optional->template Data<T>()),
reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()),
std::move(fc_scales_ptr), static_cast<int>(moe_params.num_rows),
static_cast<int>(moe_params.hidden_size),
static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
static_cast<int>(moe_params.local_num_experts), static_cast<int>(local_experts_start_index_),
static_cast<int>(k_), reinterpret_cast<char*>(work_space.get()),
reinterpret_cast<CudaT*>(fc2_output.get()), reinterpret_cast<CudaT*>(expert_scales.get()),
reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
reinterpret_cast<int*>(expert_for_source_row.get()), Stream(context));
moe_runner.run_moe_fc(
reinterpret_cast<const CudaT*>(input->template Data<T>()),
reinterpret_cast<const CudaT*>(router_probs->template Data<T>()),
reinterpret_cast<const CudaT*>(fc1_experts_weights->template Data<T>()), std::move(fc_scales_ptr),
fc1_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc1_experts_bias_optional->template Data<T>()),
activation_type_,
fc3_experts_weights_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc3_experts_weights_optional->template Data<T>()),
std::move(fc_scales_ptr),
fc3_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc3_experts_bias_optional->template Data<T>()),
reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()), std::move(fc_scales_ptr),
static_cast<int>(moe_params.num_rows), static_cast<int>(moe_params.hidden_size),
static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
static_cast<int>(moe_params.local_num_experts), static_cast<int>(local_experts_start_index_),
static_cast<int>(k_), reinterpret_cast<char*>(work_space.get()), reinterpret_cast<CudaT*>(fc2_output.get()),
reinterpret_cast<CudaT*>(expert_scales.get()),
reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
reinterpret_cast<int*>(expert_for_source_row.get()), Stream(context));

Tensor* output = context->Output(0, input->Shape());

@@ -146,12 +135,8 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
ORT_ENFORCE(moe_params.tensor_shards == nccl_->Size());
NCCL_RETURN_IF_ERROR(ncclGroupStart());
NCCL_RETURN_IF_ERROR(ncclAllReduce(reinterpret_cast<const char*>(fc2_output.get()),
reinterpret_cast<char*>(fc2_output_bc.get()),
fc2_output_size / sizeof(CudaT),
GetNcclDataType(input->DataType()),
ncclSum,
nccl_->Comm(),
Stream(context)));
reinterpret_cast<char*>(fc2_output_bc.get()), fc2_output_size / sizeof(CudaT),
GetNcclDataType(input->DataType()), ncclSum, nccl_->Comm(), Stream(context)));
NCCL_RETURN_IF_ERROR(ncclGroupEnd());
}

@@ -166,19 +151,12 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
NCCL_RETURN_IF_ERROR(ncclGroupStart());
for (int rank = 0; rank < nccl_->Size(); ++rank) {
int64_t experts_start_index = rank_to_experts_start_index_[rank];
moe_runner.get_total_rows_info(experts_start_index,
moe_params.local_num_experts,
total_past_rows,
moe_runner.get_total_rows_info(experts_start_index, moe_params.local_num_experts, total_past_rows,
total_covered_rows);
const char* src = reinterpret_cast<const char*>(fc2_output.get()) + total_past_rows * stride_bytes;
char* dst = reinterpret_cast<char*>(fc2_output_bc.get()) + total_past_rows * stride_bytes;
NCCL_RETURN_IF_ERROR(ncclBroadcast(src,
dst,
total_covered_rows * stride_count,
GetNcclDataType(input->DataType()),
rank,
nccl_->Comm(),
Stream(context)));
NCCL_RETURN_IF_ERROR(ncclBroadcast(src, dst, total_covered_rows * stride_count,
GetNcclDataType(input->DataType()), rank, nccl_->Comm(), Stream(context)));
}
NCCL_RETURN_IF_ERROR(ncclGroupEnd());
}
Expand All @@ -197,8 +175,7 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
}

template <typename T>
Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
OpKernelContext* context,
Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, OpKernelContext* context,
cudaEvent_t& cuda_event) const {
if (rank_to_experts_start_index_[0] != std::numeric_limits<int64_t>::min()) {
return Status::OK();
@@ -215,23 +192,16 @@ Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
IAllocator::MakeUniquePtr<IndexType>(allocator, nccl_->Size(), false, stream);

// Only happens in the first run.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(),
&local_experts_start_index_,
IndexTypeSize,
cudaMemcpyHostToDevice,
Stream(context)));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(), &local_experts_start_index_, IndexTypeSize,
cudaMemcpyHostToDevice, Stream(context)));
NCCL_RETURN_IF_ERROR(ncclAllGather(reinterpret_cast<const char*>(experts_start_index_d.get()),
reinterpret_cast<char*>(rank_to_experts_start_index_d.get()),
1,
GetNcclDataType(DataTypeImpl::GetType<IndexType>()),
nccl_->Comm(),
reinterpret_cast<char*>(rank_to_experts_start_index_d.get()), 1,
GetNcclDataType(DataTypeImpl::GetType<IndexType>()), nccl_->Comm(),
Stream(context)));
// The const_cast<> violates the const modifier to make sure the synchronization happens only once per session.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast<int64_t*>(rank_to_experts_start_index_.data()),
rank_to_experts_start_index_d.get(),
nccl_->Size() * IndexTypeSize,
cudaMemcpyDeviceToHost,
Stream(context)));
rank_to_experts_start_index_d.get(), nccl_->Size() * IndexTypeSize,
cudaMemcpyDeviceToHost, Stream(context)));

CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(&cuda_event, cudaEventDisableTiming));
CUDA_RETURN_IF_ERROR(cudaEventRecord(cuda_event, Stream(context)));
