add QMoE #20108

Merged (19 commits) on Mar 29, 2024

2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -60,6 +60,8 @@ set(contrib_ops_excluded_files
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"
"quantization/moe_quantization.h"
"quantization/moe_quantization.cc"
"quantization/quantize_dequantize_linear.cc"
"quantization/qordered_ops/qordered_attention_impl.cu"
"quantization/qordered_ops/qordered_attention_impl.h"
64 changes: 64 additions & 0 deletions docs/ContribOperators.md
@@ -78,6 +78,7 @@ Do not modify directly.*
* <a href="#com.microsoft.QLinearSigmoid">com.microsoft.QLinearSigmoid</a>
* <a href="#com.microsoft.QLinearSoftmax">com.microsoft.QLinearSoftmax</a>
* <a href="#com.microsoft.QLinearWhere">com.microsoft.QLinearWhere</a>
* <a href="#com.microsoft.QMoE">com.microsoft.QMoE</a>
* <a href="#com.microsoft.QOrderedAttention">com.microsoft.QOrderedAttention</a>
* <a href="#com.microsoft.QOrderedGelu">com.microsoft.QOrderedGelu</a>
* <a href="#com.microsoft.QOrderedLayerNormalization">com.microsoft.QOrderedLayerNormalization</a>
@@ -4261,6 +4262,69 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
</dl>


### <a name="com.microsoft.QMoE"></a><a name="com.microsoft.qmoe">**com.microsoft.QMoE**</a>

Int4 mixture-of-experts (MoE): expert weights are 4-bit quantized and packed two per uint8 byte, with per-column dequantization scales.

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>activation_type</tt> : string</dt>
<dd>Activation function to use. Choose from relu, gelu, silu and identity. Default is relu</dd>
<dt><tt>k</tt> : int</dt>
<dd>Number of top experts to select from expert pool</dd>
<dt><tt>normalize_routing_weights</tt> : int</dt>
<dd>Whether to normalize routing weights</dd>
</dl>
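
The `k` and `normalize_routing_weights` attributes control how each token is dispatched to experts. The sketch below is a minimal CPU reference for that routing step, written under the assumption that the gate applies a softmax over `router_probs` before the top-k selection (as in the existing non-quantized MoE contrib op); it is illustrative only, not the kernel's implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <numeric>
#include <utility>
#include <vector>

// Reference routing for a single token: softmax over the expert logits,
// select the top-k experts, and optionally renormalize the selected
// weights so they sum to 1 (normalize_routing_weights = 1).
// Assumes 0 < k <= router_logits.size().
std::vector<std::pair<int, float>> RouteToken(const std::vector<float>& router_logits,
                                              int k, bool normalize_routing_weights) {
  // Numerically stable softmax over num_experts logits.
  const float max_logit = *std::max_element(router_logits.begin(), router_logits.end());
  std::vector<float> probs(router_logits.size());
  float sum = 0.0f;
  for (size_t e = 0; e < router_logits.size(); ++e) {
    probs[e] = std::exp(router_logits[e] - max_logit);
    sum += probs[e];
  }
  for (float& p : probs) p /= sum;

  // Indices of the k largest probabilities.
  std::vector<int> order(probs.size());
  std::iota(order.begin(), order.end(), 0);
  std::partial_sort(order.begin(), order.begin() + k, order.end(),
                    [&](int a, int b) { return probs[a] > probs[b]; });

  std::vector<std::pair<int, float>> selected;
  float selected_sum = 0.0f;
  for (int i = 0; i < k; ++i) {
    selected.emplace_back(order[i], probs[order[i]]);
    selected_sum += probs[order[i]];
  }
  if (normalize_routing_weights) {
    for (auto& expert_and_weight : selected) expert_and_weight.second /= selected_sum;
  }
  return selected;
}
```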

#### Inputs (7 - 11)

<dl>
<dt><tt>input</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>router_probs</tt> : T</dt>
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
<dt><tt>fc1_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size / 2)</dd>
<dt><tt>fc1_scales</tt> : T</dt>
<dd>2D input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc2_experts_weights</tt> : T1</dt>
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size / 2)</dd>
<dt><tt>fc2_scales</tt> : T</dt>
<dd>2D input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
<dt><tt>fc3_experts_weights</tt> (optional) : T1</dt>
<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)</dd>
<dt><tt>fc3_scales</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
<dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : T</dt>
<dd>2D output tensor with shape (num_rows, hidden_size) or 3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16)</dt>
<dd>Constrain input and output types to float16 tensors.</dd>
<dt><tt>T1</tt> : tensor(uint8)</dt>
<dd>Constrain weights type to uint8 tensors.</dd>
</dl>
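
For intuition about the packed shapes above, the sketch below expands one expert's `fc1_experts_weights` slice back to float: each uint8 byte stores two 4-bit values along the last dimension (hence `inter_size / 2`), and `fc1_scales` provides one scale per output column. The nibble order and the implicit zero point of 8 are assumptions made for illustration; the CUDA kernel defines the actual packing.

```cpp
#include <cstdint>
#include <vector>

// Dequantize one expert's fc1 weight slice.
// packed : hidden_size x (inter_size / 2) bytes, two 4-bit weights per byte.
// scales : inter_size per-output-column scales (fc1_scales for this expert).
// Returns a row-major hidden_size x inter_size float matrix.
std::vector<float> DequantizeFc1Expert(const std::vector<uint8_t>& packed,
                                       const std::vector<float>& scales,
                                       int hidden_size, int inter_size) {
  std::vector<float> out(static_cast<size_t>(hidden_size) * inter_size);
  for (int row = 0; row < hidden_size; ++row) {
    for (int col = 0; col < inter_size; col += 2) {
      const uint8_t byte = packed[static_cast<size_t>(row) * (inter_size / 2) + col / 2];
      // Assumed layout: low nibble -> even column, high nibble -> odd column,
      // stored as unsigned 4-bit values with an implicit zero point of 8.
      const int lo = static_cast<int>(byte & 0x0F) - 8;
      const int hi = static_cast<int>(byte >> 4) - 8;
      out[static_cast<size_t>(row) * inter_size + col] = lo * scales[col];
      out[static_cast<size_t>(row) * inter_size + col + 1] = hi * scales[col + 1];
    }
  }
  return out;
}
```

The same relationship holds for `fc2_experts_weights` and `fc2_scales`: the last weight dimension is halved (`hidden_size / 2`) and there is one scale per output column.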


### <a name="com.microsoft.QOrderedAttention"></a><a name="com.microsoft.qorderedattention">**com.microsoft.QOrderedAttention**</a>

Quantized version of simplified Multi-Head Self-Attention (using int8 with a specific matrix layout).
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -868,6 +868,7 @@ Do not modify directly.*
|PackedAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* relative_position_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PackedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* relative_position_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float16)<br/> **T1** = tensor(uint8)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* relative_position_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale:**F**<br> *in* B:**F**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
126 changes: 48 additions & 78 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -18,23 +18,15 @@ namespace cuda {

#if defined(ORT_USE_NCCL)

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ShardedMoE, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.MayInplace(0, 0) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ShardedMoE, kMSDomain, 1, T, kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ShardedMoE<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)

using namespace ONNX_NAMESPACE;

template <typename T>
ShardedMoE<T>::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("tensor_shards", &tensor_shards_).IsOK());
@@ -69,25 +61,23 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* fc3_experts_bias_optional = context->Input<Tensor>(7);

MoEParameters moe_params(tensor_shards_);
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional,
fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional,
fc3_experts_bias_optional));
MoEQuantType quant_type = MoEQuantType::None;
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights,
fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional,
fc3_experts_weights_optional, fc3_experts_bias_optional));

ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
"num_experts should be divisible by world_size");
ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size");

if (moe_params.parallel_type == MoEParallelType::EP || moe_params.parallel_type == MoEParallelType::EPAndTP) {
ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));
}

ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm,
fc3_experts_weights_optional != nullptr,
ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm, fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);

size_t ws_size =
moe_runner.getWorkspaceSize(static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
static_cast<size_t>(moe_params.inter_size),
static_cast<size_t>(moe_params.num_experts), static_cast<size_t>(k_));
size_t ws_size = moe_runner.getWorkspaceSize(
static_cast<size_t>(moe_params.num_rows), static_cast<size_t>(moe_params.hidden_size),
static_cast<size_t>(moe_params.inter_size), static_cast<size_t>(moe_params.num_experts), static_cast<size_t>(k_));

size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
@@ -107,30 +97,29 @@

const CudaT* fc_scales_ptr = nullptr;

moe_runner.run_moe_fc(reinterpret_cast<const CudaT*>(input->template Data<T>()),
reinterpret_cast<const CudaT*>(router_probs->template Data<T>()),
reinterpret_cast<const CudaT*>(fc1_experts_weights->template Data<T>()),
std::move(fc_scales_ptr),
fc1_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc1_experts_bias_optional->template Data<T>()),
activation_type_,
fc3_experts_weights_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc3_experts_weights_optional->template Data<T>()),
std::move(fc_scales_ptr),
fc3_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc3_experts_bias_optional->template Data<T>()),
reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()),
std::move(fc_scales_ptr), static_cast<int>(moe_params.num_rows),
static_cast<int>(moe_params.hidden_size),
static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
static_cast<int>(moe_params.local_num_experts), static_cast<int>(local_experts_start_index_),
static_cast<int>(k_), reinterpret_cast<char*>(work_space.get()),
reinterpret_cast<CudaT*>(fc2_output.get()), reinterpret_cast<CudaT*>(expert_scales.get()),
reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
reinterpret_cast<int*>(expert_for_source_row.get()), Stream(context));
moe_runner.run_moe_fc(
reinterpret_cast<const CudaT*>(input->template Data<T>()),
reinterpret_cast<const CudaT*>(router_probs->template Data<T>()),
reinterpret_cast<const CudaT*>(fc1_experts_weights->template Data<T>()), std::move(fc_scales_ptr),
fc1_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc1_experts_bias_optional->template Data<T>()),
activation_type_,
fc3_experts_weights_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc3_experts_weights_optional->template Data<T>()),
std::move(fc_scales_ptr),
fc3_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc3_experts_bias_optional->template Data<T>()),
reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()), std::move(fc_scales_ptr),
static_cast<int>(moe_params.num_rows), static_cast<int>(moe_params.hidden_size),
static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
static_cast<int>(moe_params.local_num_experts), static_cast<int>(local_experts_start_index_),
static_cast<int>(k_), reinterpret_cast<char*>(work_space.get()), reinterpret_cast<CudaT*>(fc2_output.get()),
reinterpret_cast<CudaT*>(expert_scales.get()),
reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
reinterpret_cast<int*>(expert_for_source_row.get()), Stream(context));

Tensor* output = context->Output(0, input->Shape());

@@ -146,12 +135,8 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
ORT_ENFORCE(moe_params.tensor_shards == nccl_->Size());
NCCL_RETURN_IF_ERROR(ncclGroupStart());
NCCL_RETURN_IF_ERROR(ncclAllReduce(reinterpret_cast<const char*>(fc2_output.get()),
reinterpret_cast<char*>(fc2_output_bc.get()),
fc2_output_size / sizeof(CudaT),
GetNcclDataType(input->DataType()),
ncclSum,
nccl_->Comm(),
Stream(context)));
reinterpret_cast<char*>(fc2_output_bc.get()), fc2_output_size / sizeof(CudaT),
GetNcclDataType(input->DataType()), ncclSum, nccl_->Comm(), Stream(context)));
NCCL_RETURN_IF_ERROR(ncclGroupEnd());
}

@@ -166,19 +151,12 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
NCCL_RETURN_IF_ERROR(ncclGroupStart());
for (int rank = 0; rank < nccl_->Size(); ++rank) {
int64_t experts_start_index = rank_to_experts_start_index_[rank];
moe_runner.get_total_rows_info(experts_start_index,
moe_params.local_num_experts,
total_past_rows,
moe_runner.get_total_rows_info(experts_start_index, moe_params.local_num_experts, total_past_rows,
total_covered_rows);
const char* src = reinterpret_cast<const char*>(fc2_output.get()) + total_past_rows * stride_bytes;
char* dst = reinterpret_cast<char*>(fc2_output_bc.get()) + total_past_rows * stride_bytes;
NCCL_RETURN_IF_ERROR(ncclBroadcast(src,
dst,
total_covered_rows * stride_count,
GetNcclDataType(input->DataType()),
rank,
nccl_->Comm(),
Stream(context)));
NCCL_RETURN_IF_ERROR(ncclBroadcast(src, dst, total_covered_rows * stride_count,
GetNcclDataType(input->DataType()), rank, nccl_->Comm(), Stream(context)));
}
NCCL_RETURN_IF_ERROR(ncclGroupEnd());
}
@@ -197,8 +175,7 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
}

template <typename T>
Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
OpKernelContext* context,
Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, OpKernelContext* context,
cudaEvent_t& cuda_event) const {
if (rank_to_experts_start_index_[0] != std::numeric_limits<int64_t>::min()) {
return Status::OK();
@@ -215,23 +192,16 @@ Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
IAllocator::MakeUniquePtr<IndexType>(allocator, nccl_->Size(), false, stream);

// Only happens in the first run.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(),
&local_experts_start_index_,
IndexTypeSize,
cudaMemcpyHostToDevice,
Stream(context)));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(), &local_experts_start_index_, IndexTypeSize,
cudaMemcpyHostToDevice, Stream(context)));
NCCL_RETURN_IF_ERROR(ncclAllGather(reinterpret_cast<const char*>(experts_start_index_d.get()),
reinterpret_cast<char*>(rank_to_experts_start_index_d.get()),
1,
GetNcclDataType(DataTypeImpl::GetType<IndexType>()),
nccl_->Comm(),
reinterpret_cast<char*>(rank_to_experts_start_index_d.get()), 1,
GetNcclDataType(DataTypeImpl::GetType<IndexType>()), nccl_->Comm(),
Stream(context)));
// The const_cast<> violates the const modifier to make sure the synchronization happens only once per session.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast<int64_t*>(rank_to_experts_start_index_.data()),
rank_to_experts_start_index_d.get(),
nccl_->Size() * IndexTypeSize,
cudaMemcpyDeviceToHost,
Stream(context)));
rank_to_experts_start_index_d.get(), nccl_->Size() * IndexTypeSize,
cudaMemcpyDeviceToHost, Stream(context)));

CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(&cuda_event, cudaEventDisableTiming));
CUDA_RETURN_IF_ERROR(cudaEventRecord(cuda_event, Stream(context)));