From 17919717b57c1246a9f2629b7d3e9d523d80dcd6 Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Fri, 29 Mar 2024 10:24:19 -0700
Subject: [PATCH] add QMoE (#20108)
### Description
1. Introduce the latest CUTLASS extensions from TensorRT-LLM (TRTLLM), opening up a CUTLASS upgrade path (to 3.4) on the MoE side.
2. Fix a Windows build issue.
3. Add the Int4 QMoE op and unit tests (a usage sketch follows below).
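
For reference, a minimal sketch of what the new op looks like from the Python side using `onnx.helper`. The input and attribute names follow the QMoE spec added to `docs/ContribOperators.md` in this PR; the tensor names, sizes, and the `silu` choice are illustrative only.

```python
from onnx import helper

# Optional biases are passed as empty strings so the positional input
# count stays within the documented range of 7-11 inputs.
qmoe_node = helper.make_node(
    "QMoE",
    inputs=[
        "input",                # (num_rows, hidden_size), float16
        "router_probs",         # (num_rows, num_experts), float16
        "fc1_experts_weights",  # (num_experts, hidden_size, inter_size / 2), uint8
        "fc1_scales",           # (num_experts, inter_size), float16
        "",                     # fc1_experts_bias (optional, omitted)
        "fc2_experts_weights",  # (num_experts, inter_size, hidden_size / 2), uint8
        "fc2_scales",           # (num_experts, hidden_size), float16
    ],
    outputs=["output"],
    domain="com.microsoft",
    k=2,
    activation_type="silu",
    normalize_routing_weights=1,
)
```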
### Motivation and Context
---
cmake/onnxruntime_rocm_hipify.cmake | 2 +
docs/ContribOperators.md | 64 +
docs/OperatorKernels.md | 1 +
.../cuda/collective/sharded_moe.cc | 126 +-
.../contrib_ops/cuda/cuda_contrib_kernels.cc | 2 +
.../cuda/moe/cutlass_extensions/arch/mma.h | 110 ++
.../compute_occupancy.h | 21 +-
.../epilogue/thread/fused_activations.h} | 76 +-
.../epilogue_per_row_per_col_scale.h | 306 ++++
.../threadblock/epilogue_tensor_op_int32.h | 247 +++
.../epilogue_helpers.h | 129 +-
.../gemm/device/gemm_universal_base_compat.h | 384 +++++
.../gemm/device/splitk_gemm_grouped.h | 476 ++++++
.../gemm/kernel/default_fpA_intB_traits.h} | 61 +-
.../gemm/kernel/default_int8_traits.h | 51 +
.../gemm/kernel/default_splitk_gemm_grouped.h | 206 +++
.../gemm/kernel/fpA_intB_gemm.h | 513 ++++++
.../gemm/kernel/gemm_moe_problem_visitor.h | 66 +
.../gemm/kernel/gemm_with_epilogue_visitor.h | 516 ++++++
.../gemm/kernel/mixed_gemm_B_layout.h | 126 ++
.../gemm/kernel/moe_cutlass_kernel.h | 471 +++++
.../gemm/kernel}/moe_problem_visitor.h | 54 +-
.../gemm/kernel/splitk_gemm_grouped.h | 464 +++++
.../gemm/threadblock/default_dq_mma.h | 120 ++
.../threadblock/default_dq_mma_multistage.h | 289 ++++
.../threadblock/default_dq_mma_pipelined.h | 245 +++
.../gemm/threadblock/default_mma.h | 283 +++
.../gemm/threadblock/default_mma_bf16.h | 345 ++++
.../gemm/threadblock/dq_mma_base.h | 237 +++
.../gemm/threadblock/dq_mma_multistage.h | 107 ++
.../dq_mma_multistage_finegrained.h | 634 +++++++
.../threadblock/dq_mma_multistage_percol.h | 586 +++++++
.../gemm/threadblock/dq_mma_pipelined.h | 379 +++++
.../gemm/warp/default_mma_tensor_op.h | 103 ++
.../warp/mma_tensorop_compute_B_with_f16.h | 283 +++
.../gemm/warp/mma_tensorop_dequantizer.h | 534 ++++++
.../moe/cutlass_extensions/gemm_configs.h | 125 ++
.../interleaved_numeric_conversion.h | 392 +++++
.../tile_interleaved_layout.h | 2 +-
.../fine_grained_scale_zero_iterator.h | 222 +++
.../cutlass_extensions/weight_only_quant_op.h | 50 +
.../cuda/moe/ft_moe/cutlass_heuristic.cc | 4 +-
.../cuda/moe/ft_moe/cutlass_heuristic.h | 2 +-
.../cuda/moe/ft_moe/ft_gemm_configs.h | 58 -
.../cuda/moe/ft_moe/moe_cutlass_kernel.h | 463 -----
.../cuda/moe/ft_moe/moe_gemm_kernels.h | 10 +-
.../moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu | 11 +-
.../moe/ft_moe/moe_gemm_kernels_fp16_uint4.cu | 30 +
.../moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu | 12 +-
.../moe/ft_moe/moe_gemm_kernels_template.h | 139 +-
.../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 168 +-
onnxruntime/contrib_ops/cuda/moe/moe.cc | 92 +-
onnxruntime/contrib_ops/cuda/moe/moe_base.h | 94 +-
.../cuda/quantization/moe_quantization.cc | 143 ++
.../cuda/quantization/moe_quantization.h | 25 +
.../core/graph/contrib_ops/contrib_defs.cc | 58 +
onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 +
onnxruntime/test/contrib_ops/moe_test.cc | 1516 ++++++++++++-----
.../transformers/test_parity_mixtral_moe.py | 6 +-
.../python/transformers/test_parity_moe.py | 4 +-
60 files changed, 10748 insertions(+), 1497 deletions(-)
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h
rename onnxruntime/contrib_ops/cuda/moe/{ft_moe => cutlass_extensions}/compute_occupancy.h (62%)
rename onnxruntime/contrib_ops/cuda/moe/{ft_moe/gemm_moe_problem_visitor.h => cutlass_extensions/epilogue/thread/fused_activations.h} (57%)
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h
rename onnxruntime/contrib_ops/cuda/moe/{ft_moe => cutlass_extensions}/epilogue_helpers.h (57%)
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h
rename onnxruntime/contrib_ops/cuda/moe/{ft_moe/layout_traits_helper.h => cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h} (71%)
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_int8_traits.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h
rename onnxruntime/contrib_ops/cuda/moe/{ft_moe => cutlass_extensions/gemm/kernel}/moe_problem_visitor.h (79%)
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma_bf16.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h
rename onnxruntime/contrib_ops/cuda/moe/{ft_moe => cutlass_extensions}/tile_interleaved_layout.h (98%)
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h
delete mode 100644 onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h
delete mode 100644 onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h
create mode 100644 onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint4.cu
create mode 100644 onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
create mode 100644 onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h
diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index cadb06bb38707..0051f241e4f9b 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -60,6 +60,8 @@ set(contrib_ops_excluded_files
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"
+ "quantization/moe_quantization.h"
+ "quantization/moe_quantization.cc"
"quantization/quantize_dequantize_linear.cc"
"quantization/qordered_ops/qordered_attention_impl.cu"
"quantization/qordered_ops/qordered_attention_impl.h"
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 32a4ca16b7824..9b45cc02708d6 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -78,6 +78,7 @@ Do not modify directly.*
* com.microsoft.QLinearSigmoid
* com.microsoft.QLinearSoftmax
* com.microsoft.QLinearWhere
+ * com.microsoft.QMoE
* com.microsoft.QOrderedAttention
* com.microsoft.QOrderedGelu
* com.microsoft.QOrderedLayerNormalization
@@ -4261,6 +4262,69 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.QMoE**
+
+ Int4 MoE
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- activation_type : string
+  - Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+- k : int
+  - Number of top experts to select from expert pool
+- normalize_routing_weights : int
+  - Whether to normalize routing weights
+
+
+#### Inputs (7 - 11)
+
+
+- input : T
+  - 2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+- router_probs : T
+  - 2D input tensor with shape (num_rows, num_experts)
+- fc1_experts_weights : T1
+  - 3D input tensor with shape (num_experts, hidden_size, inter_size / 2)
+- fc1_scales : T
+  - 2D input tensor with shape (num_experts, inter_size)
+- fc1_experts_bias (optional) : T
+  - 2D optional input tensor with shape (num_experts, inter_size)
+- fc2_experts_weights : T1
+  - 3D input tensor with shape (num_experts, inter_size, hidden_size / 2)
+- fc2_scales : T
+  - 2D input tensor with shape (num_experts, hidden_size)
+- fc2_experts_bias (optional) : T
+  - 2D optional input tensor with shape (num_experts, hidden_size)
+- fc3_experts_weights (optional) : T1
+  - 3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)
+- fc3_scales (optional) : T
+  - 2D optional input tensor with shape (num_experts, inter_size)
+- fc3_experts_bias (optional) : T
+  - 2D optional input tensor with shape (num_experts, inter_size)
+
+
+#### Outputs
+
+
+- output : T
+  - 2D output tensor with shape (num_rows, hidden_size) or 3D output tensor with shape (batch_size, sequence_length, hidden_size)
+
+
+#### Type Constraints
+
+
+- T : tensor(float16)
+  - Constrain input and output types to float16 tensors.
+- T1 : tensor(uint8)
+  - Constrain weights type to uint8 tensors.
+
+
+
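+The `/ 2` in the fc1/fc2/fc3 weight shapes comes from packing two int4 values into each uint8 byte, with the per-column float16 scales carried separately. The sketch below shows a generic symmetric int4 packing for a single expert's (hidden_size, inter_size) weight; the exact nibble order and interleaving expected by the CUDA kernels is produced by the quantization helpers used in the unit tests, so treat this as an illustrative assumption rather than the authoritative layout.
+
+```python
+import numpy as np
+
+def pack_int4_per_column(w: np.ndarray):
+    """Illustrative symmetric int4 packing for one expert's (hidden_size, inter_size) weight.
+
+    Returns a (hidden_size, inter_size // 2) uint8 tensor (two nibbles per byte) and an
+    (inter_size,) float16 scale vector, matching the shapes documented above.
+    """
+    scales = np.maximum(np.abs(w).max(axis=0), 1e-6) / 7.0      # one scale per output column
+    q = np.clip(np.round(w / scales), -8, 7).astype(np.int8)    # values in the signed int4 range
+    lo = (q[:, 0::2] & 0x0F).astype(np.uint8)                   # even columns -> low nibble
+    hi = (q[:, 1::2] & 0x0F).astype(np.uint8)                   # odd columns  -> high nibble
+    return lo | (hi << 4), scales.astype(np.float16)
+```
+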
### **com.microsoft.QOrderedAttention**
Quantized version of simplified Multi-Head Self Attention(using int8 with specific matrix Layout).
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index bca8e17b3dfd4..c963781435465 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -868,6 +868,7 @@ Do not modify directly.*
|PackedAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* relative_position_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PackedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* relative_position_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br> **T2** = tensor(int8)<br> **T3** = tensor(float), tensor(float16)<br> **T4** = tensor(int32)|
+|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float16)<br> **T1** = tensor(uint8)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* relative_position_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br> **Q** = tensor(int8)<br> **S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br> **S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale:**F**<br> *in* B:**F**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)<br> **Q** = tensor(int8)<br> **S** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
index 2efc37cf98010..1dbbe8c4e7eaa 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -18,23 +18,15 @@ namespace cuda {
#if defined(ORT_USE_NCCL)
-#define REGISTER_KERNEL_TYPED(T) \
- ONNX_OPERATOR_TYPED_KERNEL_EX( \
- ShardedMoE, \
- kMSDomain, \
- 1, \
- T, \
- kCudaExecutionProvider, \
- (*KernelDefBuilder::Create()) \
- .MayInplace(0, 0) \
- .TypeConstraint("T", DataTypeImpl::GetTensorType()), \
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ ShardedMoE, kMSDomain, 1, T, kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()).MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), \
ShardedMoE);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
-using namespace ONNX_NAMESPACE;
-
template
ShardedMoE::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr("tensor_shards", &tensor_shards_).IsOK());
@@ -69,25 +61,23 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
const Tensor* fc3_experts_bias_optional = context->Input(7);
MoEParameters moe_params(tensor_shards_);
- ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional,
- fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional,
- fc3_experts_bias_optional));
+ MoEQuantType quant_type = MoEQuantType::None;
+ ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights,
+ fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional,
+ fc3_experts_weights_optional, fc3_experts_bias_optional));
- ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
- "num_experts should be divisible by world_size");
+ ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size");
if (moe_params.parallel_type == MoEParallelType::EP || moe_params.parallel_type == MoEParallelType::EPAndTP) {
ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));
}
- ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm,
- fc3_experts_weights_optional != nullptr,
+ ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);
- size_t ws_size =
- moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size),
- static_cast(moe_params.inter_size),
- static_cast(moe_params.num_experts), static_cast(k_));
+ size_t ws_size = moe_runner.getWorkspaceSize(
+ static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size),
+ static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), static_cast(k_));
size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
@@ -107,30 +97,29 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
const CudaT* fc_scales_ptr = nullptr;
- moe_runner.run_moe_fc(reinterpret_cast(input->template Data()),
- reinterpret_cast(router_probs->template Data()),
- reinterpret_cast(fc1_experts_weights->template Data()),
- std::move(fc_scales_ptr),
- fc1_experts_bias_optional == nullptr
- ? nullptr
- : reinterpret_cast(fc1_experts_bias_optional->template Data()),
- activation_type_,
- fc3_experts_weights_optional == nullptr
- ? nullptr
- : reinterpret_cast(fc3_experts_weights_optional->template Data()),
- std::move(fc_scales_ptr),
- fc3_experts_bias_optional == nullptr
- ? nullptr
- : reinterpret_cast(fc3_experts_bias_optional->template Data()),
- reinterpret_cast(fc2_experts_weights->template Data()),
- std::move(fc_scales_ptr), static_cast(moe_params.num_rows),
- static_cast(moe_params.hidden_size),
- static_cast(moe_params.inter_size), static_cast(moe_params.num_experts),
- static_cast(moe_params.local_num_experts), static_cast(local_experts_start_index_),
- static_cast(k_), reinterpret_cast(work_space.get()),
- reinterpret_cast(fc2_output.get()), reinterpret_cast(expert_scales.get()),
- reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()),
- reinterpret_cast(expert_for_source_row.get()), Stream(context));
+ moe_runner.run_moe_fc(
+ reinterpret_cast(input->template Data()),
+ reinterpret_cast(router_probs->template Data()),
+ reinterpret_cast(fc1_experts_weights->template Data()), std::move(fc_scales_ptr),
+ fc1_experts_bias_optional == nullptr
+ ? nullptr
+ : reinterpret_cast(fc1_experts_bias_optional->template Data()),
+ activation_type_,
+ fc3_experts_weights_optional == nullptr
+ ? nullptr
+ : reinterpret_cast(fc3_experts_weights_optional->template Data()),
+ std::move(fc_scales_ptr),
+ fc3_experts_bias_optional == nullptr
+ ? nullptr
+ : reinterpret_cast(fc3_experts_bias_optional->template Data()),
+ reinterpret_cast(fc2_experts_weights->template Data()), std::move(fc_scales_ptr),
+ static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size),
+ static_cast(moe_params.inter_size), static_cast(moe_params.num_experts),
+ static_cast(moe_params.local_num_experts), static_cast(local_experts_start_index_),
+ static_cast(k_), reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()),
+ reinterpret_cast(expert_scales.get()),
+ reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()),
+ reinterpret_cast(expert_for_source_row.get()), Stream(context));
Tensor* output = context->Output(0, input->Shape());
@@ -146,12 +135,8 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
ORT_ENFORCE(moe_params.tensor_shards == nccl_->Size());
NCCL_RETURN_IF_ERROR(ncclGroupStart());
NCCL_RETURN_IF_ERROR(ncclAllReduce(reinterpret_cast(fc2_output.get()),
- reinterpret_cast(fc2_output_bc.get()),
- fc2_output_size / sizeof(CudaT),
- GetNcclDataType(input->DataType()),
- ncclSum,
- nccl_->Comm(),
- Stream(context)));
+ reinterpret_cast(fc2_output_bc.get()), fc2_output_size / sizeof(CudaT),
+ GetNcclDataType(input->DataType()), ncclSum, nccl_->Comm(), Stream(context)));
NCCL_RETURN_IF_ERROR(ncclGroupEnd());
}
@@ -166,19 +151,12 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
NCCL_RETURN_IF_ERROR(ncclGroupStart());
for (int rank = 0; rank < nccl_->Size(); ++rank) {
int64_t experts_start_index = rank_to_experts_start_index_[rank];
- moe_runner.get_total_rows_info(experts_start_index,
- moe_params.local_num_experts,
- total_past_rows,
+ moe_runner.get_total_rows_info(experts_start_index, moe_params.local_num_experts, total_past_rows,
total_covered_rows);
const char* src = reinterpret_cast(fc2_output.get()) + total_past_rows * stride_bytes;
char* dst = reinterpret_cast(fc2_output_bc.get()) + total_past_rows * stride_bytes;
- NCCL_RETURN_IF_ERROR(ncclBroadcast(src,
- dst,
- total_covered_rows * stride_count,
- GetNcclDataType(input->DataType()),
- rank,
- nccl_->Comm(),
- Stream(context)));
+ NCCL_RETURN_IF_ERROR(ncclBroadcast(src, dst, total_covered_rows * stride_count,
+ GetNcclDataType(input->DataType()), rank, nccl_->Comm(), Stream(context)));
}
NCCL_RETURN_IF_ERROR(ncclGroupEnd());
}
@@ -197,8 +175,7 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
}
template
-Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
- OpKernelContext* context,
+Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, OpKernelContext* context,
cudaEvent_t& cuda_event) const {
if (rank_to_experts_start_index_[0] != std::numeric_limits::min()) {
return Status::OK();
@@ -215,23 +192,16 @@ Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
IAllocator::MakeUniquePtr(allocator, nccl_->Size(), false, stream);
// Only happens in the first run.
- CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(),
- &local_experts_start_index_,
- IndexTypeSize,
- cudaMemcpyHostToDevice,
- Stream(context)));
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(), &local_experts_start_index_, IndexTypeSize,
+ cudaMemcpyHostToDevice, Stream(context)));
NCCL_RETURN_IF_ERROR(ncclAllGather(reinterpret_cast(experts_start_index_d.get()),
- reinterpret_cast(rank_to_experts_start_index_d.get()),
- 1,
- GetNcclDataType(DataTypeImpl::GetType()),
- nccl_->Comm(),
+ reinterpret_cast(rank_to_experts_start_index_d.get()), 1,
+ GetNcclDataType(DataTypeImpl::GetType()), nccl_->Comm(),
Stream(context)));
// The const_cast<> violates the const modifier to make sure the synchronization happens only once per session.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast(rank_to_experts_start_index_.data()),
- rank_to_experts_start_index_d.get(),
- nccl_->Size() * IndexTypeSize,
- cudaMemcpyDeviceToHost,
- Stream(context)));
+ rank_to_experts_start_index_d.get(), nccl_->Size() * IndexTypeSize,
+ cudaMemcpyDeviceToHost, Stream(context)));
CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(&cuda_event, cudaEventDisableTiming));
CUDA_RETURN_IF_ERROR(cudaEventRecord(cuda_event, Stream(context)));
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 57e951d3a68ff..3621ffc5c64ca 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -72,6 +72,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QMoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
@@ -275,6 +276,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h
new file mode 100644
index 0000000000000..07c38c58e446a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h
@@ -0,0 +1,110 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Templates exposing architecture support for multiply-add operations
+*/
+
+#pragma once
+#include "contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace arch {
+
+// Tag which triggers an MMA variant that dequantizes interleaved operand B before the multiply-add
+struct OpMultiplyAddDequantizeInterleavedBToA;
+
+/*
+ Below we have extra tags to signal what kind of dequantization we want to do
+ (per-column scale, fine-grained scale only, fine-grained with zero). This still lets us
+ use the existing template infrastructure (incl. that in CUTLASS). However, we
+ split out the template below into OpMultiplyAddDequantizeInterleavedBToA along
+ with the quantization op before instantiating the GEMM pieces.
+
+ Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
+ code we need to duplicate.
+ */
+struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
+struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
+struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
+
+// The default just forwards the original operator
+template
+struct TagOperator {
+ using TaggedOperator = MmaOp;
+};
+
+// Specializations below attach more information to the operator
+template <>
+struct TagOperator {
+ using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
+};
+
+template <>
+struct TagOperator {
+ using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
+};
+
+template <>
+struct TagOperator {
+ using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
+};
+
+// Here we instantiate some structs to "detag" the tagged operator. This splits it back into the original
+// operator plus the extra information. If no extra info was tagged, the dequant op defaults to
+// per-column scaling.
+template
+struct DetagOperator {
+ using Operator = TaggedMmaOp;
+ static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
+};
+
+template <>
+struct DetagOperator {
+ using Operator = OpMultiplyAddDequantizeInterleavedBToA;
+ static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
+};
+
+template <>
+struct DetagOperator {
+ using Operator = OpMultiplyAddDequantizeInterleavedBToA;
+ static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
+};
+
+template <>
+struct DetagOperator {
+ using Operator = OpMultiplyAddDequantizeInterleavedBToA;
+ static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
+};
+
+} // namespace arch
+} // namespace cutlass
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h
similarity index 62%
rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h
index 86136ea244e23..99cbe4a66049e 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h
@@ -26,19 +26,22 @@ namespace ort_fastertransformer {
template
inline int compute_occupancy_for_kernel() {
- int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
+ int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage));
if (smem_size > (48 << 10)) {
- cudaError_t status =
- cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
- if (status == cudaError::cudaErrorInvalidValue) {
- // Clear the error bit since we can ignore this.
- // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
- // occupancy of 0. This will cause the heuristic to ignore this configuration.
- status = cudaGetLastError();
+ cudaFuncAttributes attr;
+ int device = 0;
+ int max_smem_per_block = 0;
+ CUDA_CALL_THROW(cudaGetDevice(&device));
+ CUDA_CALL_THROW(cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
+ CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::Kernel));
+ if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) {
+ // This should mean that
+ // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
+ // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this
+ // configuration.
return 0;
}
- CUDA_CALL_THROW(status);
}
int max_active_blocks = -1;
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h
similarity index 57%
rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h
rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h
index 311ed323cb90c..da8cb6d294efd 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h
@@ -28,52 +28,68 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
-
/*! \file
- \brief Scheduler for grouped GEMM
+ \brief Functor performing linear combination with a maximum operation used by epilogues.
*/
#pragma once
+#include "cutlass/array.h"
#include "cutlass/cutlass.h"
-#include "cutlass/gemm/gemm.h"
-#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
-#include "cutlass/matrix_coord.h"
-
-#include "moe_problem_visitor.h"
+#include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/epilogue/thread/linear_combination_generic.h"
+#include "cutlass/epilogue/thread/scale_type.h"
+#include "cutlass/functional.h"
+#include "cutlass/half.h"
+#include "cutlass/numeric_conversion.h"
+#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
-namespace gemm {
-namespace kernel {
+namespace epilogue {
+namespace thread {
-/// Visitor class to abstract away the algorithm for iterating over tiles
-template
-struct GemmMoeProblemVisitor
- : public MoeProblemVisitor, ThreadblockShape,
- GroupScheduleMode_, PrefetchTileCount, ThreadCount> {
- static bool const kTransposed = Transposed;
+/////////////////////////////////////////////////////////////////////////////////////////////////
- using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper;
- using Base =
- MoeProblemVisitor;
- using Params = typename Base::Params;
- using SharedStorage = typename Base::SharedStorage;
+__forceinline__ __device__ float copysignf_pos(float a, float b) {
+ float r;
+ r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
+ return r;
+}
- //
- // Methods
- //
- CUTLASS_DEVICE
- GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
- : Base(params_, shared_storage_, block_idx) {}
-};
+__forceinline__ __device__ float tanh_opt(float x) {
+#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
+ float const exp_val = -1.f * fabs(2 * x);
+ return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
+#else
+ return fast_tanh(x);
+#endif
+}
/////////////////////////////////////////////////////////////////////////////////////////////////
+template <>
+struct GELU_taylor {
+ static bool const kIsHeavy = true;
+
+ CUTLASS_DEVICE
+ float operator()(float const& z) const {
+ float k0 = static_cast(0.7978845608028654);
+ float k1 = static_cast(0.044715);
+
+ return static_cast(
+ cutlass::constants::half() * z *
+ (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z))));
+ }
+
+ using Params = LinearCombinationGenericParams;
+
+ CUTLASS_DEVICE
+ float operator()(float const& scalar, Params const& params_) const { return this->operator()(scalar); }
+};
-} // namespace kernel
-} // namespace gemm
+} // namespace thread
+} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
new file mode 100644
index 0000000000000..affd1d83a35de
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
@@ -0,0 +1,306 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.
+
+ original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
+
+*/
+
+#pragma once
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+#include "cutlass/arch/memory.h"
+#include "cutlass/arch/memory_sm75.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/fast_math.h"
+#include "cutlass/numeric_conversion.h"
+#include "tensorrt_llm/common/quantization.h"
+
+namespace tk = tensorrt_llm::common;
+
+namespace cutlass {
+namespace epilogue {
+namespace threadblock {
+
+template
+class EpilogueVisitorPerRowPerCol {
+ public:
+ using ThreadblockShape = ThreadblockShape_;
+ static int const kThreadCount = ThreadCount;
+
+ using ScaleTileIterator = ScaleTileIterator_;
+ using OutputTileIterator = OutputTileIterator_;
+ using ElementwiseFunctor = ElementwiseFunctor_;
+
+ static int const kIterations = OutputTileIterator::kIterations;
+ static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
+
+ using ElementOutput = typename OutputTileIterator::Element;
+ using LayoutOutput = cutlass::layout::RowMajor;
+ using ElementAccumulator = ElementAccumulator_;
+
+ using AlphaScaleElementType = typename ScaleTileIterator::Element;
+
+ using ElementCompute = ElementCompute_;
+ using AccumulatorFragment = Array;
+ using ComputeFragment = Array;
+ using OutputVector = Array;
+
+ static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
+ static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
+
+ /// Argument structure
+ struct Arguments {
+ typename ElementwiseFunctor::Params elementwise;
+ int64_t batch_stride_alpha;
+ int64_t batch_stride_C;
+ int64_t batch_stride_D;
+
+ //
+ // Methods
+ //
+ Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
+
+ explicit Arguments(typename ElementwiseFunctor::Params elementwise_)
+ : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
+
+ Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, int64_t batch_stride_C_,
+ int64_t batch_stride_D_)
+ : elementwise(elementwise_),
+ batch_stride_alpha(batch_stride_alpha_),
+ batch_stride_C(batch_stride_C_),
+ batch_stride_D(batch_stride_D_) {}
+ };
+
+ struct Params {
+ typename ElementwiseFunctor::Params elementwise;
+ int64_t batch_stride_alpha;
+ int64_t batch_stride_C;
+ int64_t batch_stride_D;
+
+ //
+ // Methods
+ //
+ CUTLASS_HOST_DEVICE
+ Params() {}
+
+ CUTLASS_HOST_DEVICE
+ explicit Params(Arguments const& args)
+ : elementwise(args.elementwise),
+ batch_stride_alpha(args.batch_stride_alpha),
+ batch_stride_C(args.batch_stride_C),
+ batch_stride_D(args.batch_stride_D) {}
+ };
+
+ /// Shared storage
+ struct SharedStorage {};
+
+ private:
+ Params const& params_;
+ SharedStorage& shared_storage_;
+ MatrixCoord extent_;
+ MatrixCoord extent_real_;
+ ElementwiseFunctor elementwise_;
+
+ bool const per_token_quant_;
+ bool const per_channel_quant_;
+
+ AlphaScaleElementType* ptr_alpha_row_;
+ AlphaScaleElementType* ptr_alpha_col_;
+ ScaleTileIterator iterator_alpha_col_;
+ OutputTileIterator iterator_C_;
+ OutputTileIterator iterator_D_;
+
+ AlphaScaleElementType element_alpha_row_ = 1.0f;
+ AlphaScaleElementType element_alpha_col_ = 1.0f;
+ typename ScaleTileIterator::Fragment fragment_alpha_col_;
+ typename OutputTileIterator::Fragment fragment_C_;
+ typename OutputTileIterator::Fragment fragment_D_;
+
+ ElementAccumulator beta_;
+
+ int column_offset_;
+
+ MatrixCoord thread_offset_;
+
+ public:
+ CUTLASS_DEVICE
+ EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
+ cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
+ typename ScaleTileIterator::Params params_alpha_col,
+ typename OutputTileIterator::Params params_C,
+ typename OutputTileIterator::Params params_D, tk::QuantMode quant_option,
+ AlphaScaleElementType* ptr_alpha_row, AlphaScaleElementType* ptr_alpha_col,
+ typename OutputTileIterator::Element* ptr_C, typename OutputTileIterator::Element* ptr_D,
+ cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0),
+ int column_offset = 0,
+ cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
+ : params_(params),
+ shared_storage_(shared_storage),
+ extent_(problem_size),
+ elementwise_(params.elementwise),
+ per_token_quant_(quant_option.hasPerTokenScaling()),
+ per_channel_quant_(quant_option.hasPerChannelScaling()),
+ ptr_alpha_row_(ptr_alpha_row),
+ ptr_alpha_col_(ptr_alpha_col),
+ iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
+ iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
+ iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
+ extent_real_(problem_size_real) {
+ beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);
+
+ if (beta_ == ElementAccumulator()) {
+ iterator_C_.clear_mask();
+ }
+
+ if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) {
+ element_alpha_col_ = *ptr_alpha_col_;
+ }
+
+ if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) {
+ element_alpha_row_ = *ptr_alpha_row_;
+ }
+ }
+
+ /// Helper to indicate split-K behavior
+ CUTLASS_DEVICE
+ void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
+ int split_k_slices) { ///< Total number of split-K slices
+ }
+
+ /// Called to set the batch index
+ CUTLASS_DEVICE
+ void set_batch_index(int batch_idx) {
+ iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
+ iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
+ iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
+ }
+
+ /// Called at the start of the epilogue just before iterating over accumulator slices
+ CUTLASS_DEVICE
+ void begin_epilogue() {
+ if (per_channel_quant_) {
+ iterator_alpha_col_.load(fragment_alpha_col_);
+ }
+ }
+
+ /// Called at the start of one step before starting accumulator exchange
+ CUTLASS_DEVICE
+ void begin_step(int step_idx) {
+ fragment_D_.clear();
+ fragment_C_.clear();
+
+ if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
+ iterator_C_.load(fragment_C_);
+ ++iterator_C_;
+ }
+ }
+
+ /// Called at the start of a row
+ CUTLASS_DEVICE
+ void begin_row(int row_idx) {
+ // load alpha_row in begin_step only when per token(row) scaling is used
+ if (per_token_quant_) {
+ int thread_offset_row =
+ iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
+
+ arch::global_load(
+ element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
+ }
+ }
+
+ /// Called after accumulators have been exchanged for each accumulator vector
+ CUTLASS_DEVICE
+ void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) {
+ NumericArrayConverter source_converter;
+
+ ComputeFragment result = source_converter(accum);
+ if (per_channel_quant_) {
+ ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx];
+ result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
+ } else {
+ result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
+ }
+
+ // Convert to the output
+ NumericArrayConverter output_converter;
+ OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx];
+ output = output_converter(result);
+ }
+
+ /// Called at the end of a row
+ CUTLASS_DEVICE
+ void end_row(int row_idx) {}
+
+ /// Called after all accumulator elements have been visited
+ CUTLASS_DEVICE
+ void end_step(int step_idx) {
+ iterator_D_.store(fragment_D_);
+ ++iterator_D_;
+ }
+
+ /// Called after all steps have been completed
+ CUTLASS_DEVICE
+ void end_epilogue() {}
+
+ private:
+ CUTLASS_DEVICE
+ ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, ComputeFragment const& scale_col,
+ AlphaScaleElementType const& scale_row) {
+ ComputeFragment result;
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < ComputeFragment::kElements; ++i) {
+ result[i] = accum[i] * (scale_col[i] * scale_row);
+ }
+
+ return result;
+ }
+
+ CUTLASS_DEVICE
+ ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, AlphaScaleElementType const& scale_col,
+ AlphaScaleElementType const& scale_row) {
+ ComputeFragment result;
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < ComputeFragment::kElements; ++i) {
+ result[i] = accum[i] * (scale_col * scale_row);
+ }
+
+ return result;
+ }
+};
+
+} // namespace threadblock
+} // namespace epilogue
+} // namespace cutlass
diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h
new file mode 100644
index 0000000000000..40f126d56616a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h
@@ -0,0 +1,247 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
+
+ The epilogue rearranges the result of a matrix product through shared memory to match canonical
+ tensor layouts in global memory. Epilogues support conversion and reduction operations.
+
+ original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
+
+*/
+
+#pragma once
+
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/platform/platform.h"
+
+#include "cutlass/gemm/gemm.h"
+
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/epilogue/thread/linear_combination_clamp.h"
+#include "cutlass/epilogue/thread/linear_combination_gelu.h"
+#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
+#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
+#include "cutlass/epilogue/thread/linear_combination_relu.h"
+#include "cutlass/epilogue/thread/linear_combination_relu0.h"
+#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
+
+#include "cutlass/epilogue/thread/conversion_op.h"
+#include "cutlass/epilogue/thread/reduction_op.h"
+
+#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
+
+#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
+#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
+#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
+#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
+#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
+#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
+#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
+#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
+#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
+
+#include "cutlass/epilogue/threadblock/epilogue.h"
+#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
+
+#include "cutlass/layout/permute.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace threadblock {
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
+template
+struct DefaultIteratorsTensorOp {
+ using WarpTileIterator =
+ cutlass::epilogue::warp::TileIteratorTensorOpMixed;
+
+ using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed;
+
+ static int const kFragmentsPerIteration = 2;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace detail
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Tile iterator used to load output tile from shared memory in epilogue.
+///
+/// Satisfies: ReadableTileIterator
+///
+template
+class SharedLoadIteratorMixed {
+ public:
+ using ThreadMap = ThreadMap_;
+ using Shape = typename ThreadMap::Shape;
+
+ using Element = int32_t;
+
+ using Layout = layout::RowMajor;
+ using TensorRef = TensorRef;
+ using ConstTensorRef = typename TensorRef::ConstTensorRef;
+
+ using Index = typename Layout::Index;
+ using LongIndex = typename Layout::LongIndex;
+ using TensorCoord = MatrixCoord;
+
+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
+
+ static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8;
+
+ static int const kThreads = ThreadMap::kThreads;
+
+ /// Fragment object
+ using Fragment =
+ Array;
+
+ /// Memory access size
+ using AccessType = AlignedArray;
+
+ /// Vector type used for SMEM loads
+ using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess),
+ const_min(16, kAlignment)>;
+
+ static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;
+
+ private:
+ //
+ // Data members
+ //
+
+ /// Byte-level pointer
+ LoadType const* pointers_[kLoadsPerAccess];
+
+ /// Stride along adjacent rows in units of LoadType
+ int stride_;
+
+ public:
+ //
+ // Methods
+ //
+
+ /// Constructor
+ CUTLASS_DEVICE
+ SharedLoadIteratorMixed(TensorRef ref, int thread_idx) : stride_((ref.stride(0) / LoadType::kElements)) {
+ TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
+
+ // Initialize pointers
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kLoadsPerAccess; ++i) {
+ pointers_[i] = reinterpret_cast(ref.data());
+
+ int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
+ int bank_offset = (col_idx * static_cast(sizeof(LoadType)) / 128) % kLoadsPerAccess;
+
+ col_idx += (bank_offset + i) % kLoadsPerAccess;
+
+ pointers_[i] += thread_offset.row() * stride_ + col_idx;
+ }
+ }
+
+ /// Adds a pointer offset in units of Element
+ CUTLASS_HOST_DEVICE
+ void add_pointer_offset(LongIndex pointer_offset) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kLoadsPerAccess; ++i) {
+ pointers_[i] += pointer_offset / LoadType::kElements;
+ }
+ }
+
+ CUTLASS_DEVICE
+ void add_tile_offset(TensorCoord const& offset) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int i = 0; i < kLoadsPerAccess; ++i) {
+ pointers_[i] += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
+ }
+ }
+
+ /// Loads a fragment from memory
+ CUTLASS_DEVICE
+ void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const {
+ CUTLASS_PRAGMA_UNROLL
+ for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
+ CUTLASS_PRAGMA_UNROLL
+ for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
+ int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_ +
+ cluster * ThreadMap::Delta::kCluster * stride_ + pointer_offset / LoadType::kElements;
+
+ int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
+
+ LoadType* frag_ptr = reinterpret_cast(&frag);
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
+ int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;
+
+ CUTLASS_PRAGMA_UNROLL
+ for (int v = 0; v < kLoadsPerAccess; ++v) {
+ int vector_idx = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);
+
+ LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
+
+ frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /// Loads a fragment
+ CUTLASS_DEVICE
+ void load(Fragment& frag) const { load_with_pointer_offset(frag, 0); }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace epilogue
+} // namespace cutlass
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h
similarity index 57%
rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h
index b18a70e899d1c..b784646c31f84 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h
@@ -1,11 +1,12 @@
/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -24,139 +25,85 @@
#pragma once
-#include "cutlass/array.h"
-#include "cutlass/cutlass.h"
-#include "cutlass/epilogue/thread/activation.h"
-#include "cutlass/epilogue/thread/scale_type.h"
-#include "cutlass/functional.h"
-#include "cutlass/half.h"
-#include "cutlass/numeric_conversion.h"
-#include "cutlass/numeric_types.h"
+#include "contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
-namespace cutlass {
-namespace epilogue {
-namespace thread {
-
-__forceinline__ __device__ float copysignf_pos(float a, float b) {
- float r;
- r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
- return r;
-}
-
-__forceinline__ __device__ float tanh_opt(float x) {
-#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
- const float exp_val = -1.f * fabs(2 * x);
- return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
-#else
- return fast_tanh(x);
-#endif
-}
-
-template <>
-struct GELU_taylor<float> {
- static const bool kIsHeavy = true;
- CUTLASS_DEVICE
- float operator()(float const& z) const {
- float k0 = float(0.7978845608028654);
- float k1 = float(0.044715);
-
- return float(
- cutlass::constants::half<float>() * z *
- (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
- }
-
- using Params = LinearCombinationGenericParams<float>;
-
- CUTLASS_DEVICE
- float operator()(float const& scalar, Params const& params_) const { return this->operator()(scalar); }
-};
-
-} // namespace thread
-} // namespace epilogue
-} // namespace cutlass
-
namespace ort_fastertransformer {
struct EpilogueOpBiasSilu {};
-struct EpilogueOpNoBiasSilu {};
-
struct EpilogueOpBiasReLU {};
-struct EpilogueOpNoBiasReLU {};
-
struct EpilogueOpBiasFtGelu {};
-struct EpilogueOpNoBiasFtGelu {};
+struct EpilogueOpDefaultSilu {};
+
+struct EpilogueOpDefaultReLU {};
+
+struct EpilogueOpDefaultFtGelu {};
struct EpilogueOpBias {};
-struct EpilogueOpNoBias {};
+struct EpilogueOpDefault {};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
struct Epilogue {};
+constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
+
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu> {
 using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
+ ElementAccumulator, BiasScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBiasSilu> {
- using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU> {
+ using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU> {
- using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu> {
+ using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
+ cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator,
+ ElementAccumulator, BiasScaleMode, cutlass::FloatRoundStyle::round_to_nearest, true>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBiasReLU> {
- using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias> {
+ using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};
+constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default;
+
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu> {
- using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
- cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator,
- ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling,
- cutlass::FloatRoundStyle::round_to_nearest, true>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultSilu> {
+ using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBiasFtGelu> {
- using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
- cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator,
- ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling,
- cutlass::FloatRoundStyle::round_to_nearest, true>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultReLU> {
+ using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias> {
- using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultFtGelu> {
+ using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
+ cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator,
+ ElementAccumulator, DefaultScaleMode, cutlass::FloatRoundStyle::round_to_nearest, true>;
};
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBias> {
- using Op =
- cutlass::epilogue::thread::LinearCombination<
- ElementType, ElementsPerVectorAccess, ElementAccumulator,
- ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefault> {
+ using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};
} // namespace ort_fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
new file mode 100644
index 0000000000000..f5064afc23ae0
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
@@ -0,0 +1,384 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*!
+ \file
+ \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
+ batched array variants.
+*/
+
+#pragma once
+
+// #include
+#include <limits>
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/device_kernel.h"
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/kernel/gemm_universal.h"
+#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
+
+#include "cutlass/gemm/device/default_gemm_configuration.h"
+#include "cutlass/gemm/kernel/default_gemm_universal.h"
+
+#include "cutlass/trace.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace device {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/*
+ This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
+ It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
+ and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.
+
+ Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
+ that feature at the moment.
+ */
+
+template <typename GemmKernel_>
+class GemmUniversalBaseCompat {
+ public:
+ using GemmKernel = GemmKernel_;
+ using ThreadblockShape = typename GemmKernel::Mma::Shape;
+
+ using ElementA = typename GemmKernel::ElementA;
+ using LayoutA = typename GemmKernel::LayoutA;
+ using TensorRefA = TensorRef<ElementA const, LayoutA>;
+ static ComplexTransform const kTransformA = GemmKernel::kTransformA;
+
+ using ElementB = typename GemmKernel::ElementB;
+ using LayoutB = typename GemmKernel::LayoutB;
+ using TensorRefB = TensorRef<ElementB const, LayoutB>;
+ static ComplexTransform const kTransformB = GemmKernel::kTransformB;
+
+ using ElementC = typename GemmKernel::ElementC;
+ using LayoutC = typename GemmKernel::LayoutC;
+ using TensorRefC = TensorRef<ElementC const, LayoutC>;
+ using TensorRefD = TensorRef<ElementC, LayoutC>;
+
+ using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
+
+ using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
+ using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
+ using Operator = typename GemmKernel::Operator;
+
+ /// Argument structure
+ using Arguments = typename GemmKernel::Arguments;
+
+ protected:
+ /// Kernel parameters object
+ typename GemmKernel::Params params_;
+
+ protected:
+ /// Private helper to obtain the grid dimensions with fix-up for split-K
+ static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) {
+ // Determine grid shape
+ ThreadblockSwizzle threadblock_swizzle;
+
+ grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
+ args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
+
+ gemm_k_size = args.problem_size.k();
+
+ if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
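+ // Round each split-K partition up so its K extent stays aligned for 128-bit accesses on both A and B.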
+ int const kAlignK =
+ const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
+
+ gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
+
+ if (gemm_k_size) {
+ grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
+ }
+ }
+ }
+
+ public:
+ /// Constructs the GEMM.
+ GemmUniversalBaseCompat() {}
+
+ /// Determines whether the GEMM can execute the given problem.
+ static Status can_implement(Arguments const& args) {
+ // Determine grid shape
+ cutlass::gemm::GemmCoord grid_tiled_shape;
+ int gemm_k_size = 0;
+
+ get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
+
+ ThreadblockSwizzle threadblock_swizzle;
+ dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
+
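+ // CUDA limits gridDim.y and gridDim.z to 65535; reject problems whose swizzled grid exceeds that.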
+ uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
+
+ if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
+ return Status::kErrorInvalidProblem;
+ }
+
+ return GemmKernel::can_implement(args);
+ }
+
+ /// Gets the workspace size
+ static size_t get_workspace_size(Arguments const& args) {
+ CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
+
+ size_t workspace_bytes = 0;
+
+ // Determine grid shape
+ cutlass::gemm::GemmCoord grid_tiled_shape;
+ int gemm_k_size = 0;
+
+ get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
+
+ if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
+ // Split-K parallel always requires a temporary workspace
+ workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
+ } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
+ // Serial split-K only requires a temporary workspace if the number of partitions along the
+ // GEMM K dimension is greater than one.
+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
+ }
+
+ CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
+
+ workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
+
+ return workspace_bytes;
+ }
+
+ /// Computes the grid shape
+ static dim3 get_grid_shape(Arguments const& args) {
+ CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
+
+ ThreadblockSwizzle threadblock_swizzle;
+
+ cutlass::gemm::GemmCoord grid_tiled_shape;
+ int gemm_k_size = 0;
+
+ get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
+ dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
+
+ CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n"
+ << " result = {" << result << "}");
+
+ return result;
+ }
+
+ /// Computes the maximum number of active blocks per multiprocessor
+ static int maximum_active_blocks(int smem_capacity = -1) {
+ CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
+
+ int max_active_blocks = -1;
+ int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));
+
+ CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
+
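+ // Kernels needing at most 48 KB of shared memory can be queried directly; larger configurations
+ // require opting in to dynamic shared memory, so occupancy is derived from the SMEM budget instead.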
+ if (smem_size <= (48 << 10)) {
+ cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
+ GemmKernel::kThreadCount, smem_size);
+
+ if (result == cudaSuccess) {
+ CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
+ return max_active_blocks;
+ }
+ } else {
+ // Query assuming zero shared memory then compute occupancy limit based on SMEM
+ cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
+ GemmKernel::kThreadCount, 0);
+
+ if (result != cudaSuccess) {
+ CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
+ << cudaGetErrorString(result));
+
+ return -1;
+ }
+
+ if (smem_capacity < 0) {
+ int device_idx = 0;
+ result = cudaGetDevice(&device_idx);
+
+ if (result != cudaSuccess) {
+ return -1;
+ }
+
+ cudaDeviceProp properties;
+ result = cudaGetDeviceProperties(&properties, device_idx);
+
+ if (result != cudaSuccess) {
+ return -1;
+ }
+
+ smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
+ }
+
+ int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
+
+ CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
+
+ return occupancy;
+ }
+
+ CUTLASS_TRACE_HOST(" returning internal error");
+
+ return -1;
+ }
+
+ /// Initializes GEMM state from arguments.
+ Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
+ CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
+ << workspace << ", stream: " << (stream ? "non-null" : "null"));
+
+ size_t workspace_bytes = get_workspace_size(args);
+
+ CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
+
+ if (workspace_bytes) {
+ if (!workspace) {
+ CUTLASS_TRACE_HOST(" error: device workspace must not be null");
+
+ return Status::kErrorWorkspaceNull;
+ }
+
+ if (args.mode == GemmUniversalMode::kGemm) {
+ CUTLASS_TRACE_HOST(" clearing device workspace");
+ cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
+
+ if (result != cudaSuccess) {
+ CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));
+
+ return Status::kErrorInternal;
+ }
+ }
+ }
+
+ // Get CUDA grid shape
+ cutlass::gemm::GemmCoord grid_tiled_shape;
+ int gemm_k_size = 0;
+
+ get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
+
+ // Initialize the Params structure
+ params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));
+
+ // Specify shared memory capacity for kernel.
+ int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));
+
+ if (smem_size >= (48 << 10)) {
+ cudaError_t result =
+ cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+
+ if (result != cudaSuccess) {
+ return Status::kErrorInternal;
+ }
+ }
+
+ return Status::kSuccess;
+ }
+
+ /// Lightweight update given a subset of arguments
+ Status update(Arguments const& args, void* workspace = nullptr) {
+ CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
+
+ size_t workspace_bytes = get_workspace_size(args);
+
+ if (workspace_bytes && !workspace) {
+ return Status::kErrorWorkspaceNull;
+ }
+
+ params_.update(args, workspace);
+
+ return Status::kSuccess;
+ }
+
+ /// Runs the kernel using initialized state.
+ Status run(cudaStream_t stream = nullptr) {
+ CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
+
+ //
+ // Configure grid and block dimensions
+ //
+
+ ThreadblockSwizzle threadblock_swizzle;
+
+ dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
+ dim3 block(GemmKernel::kThreadCount, 1, 1);
+
+ int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));
+
+ //
+ // Launch kernel
+ //
+
+ CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
+
+ // Launch
+ cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
+
+ //
+ // Query for errors
+ //
+ cudaError_t result = cudaGetLastError();
+
+ if (result != cudaSuccess) {
+ CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
+ return Status::kErrorInternal;
+ }
+
+ return Status::kSuccess;
+ }
+
+ /// Runs the kernel using initialized state.
+ Status operator()(cudaStream_t stream = nullptr) { return run(stream); }
+
+ /// Runs the kernel using initialized state.
+ Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
+ Status status = initialize(args, workspace, stream);
+
+ if (status == Status::kSuccess) {
+ status = run(stream);
+ }
+
+ return status;
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace device
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h
new file mode 100644
index 0000000000000..b226b73e86fe1
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h
@@ -0,0 +1,476 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*!
+ \file
+ \brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
+*/
+
+#pragma once
+
+#include <algorithm>
+#include <limits>
+#include <numeric>
+#include <vector>
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/device_kernel.h"
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/kernel/gemm_universal.h"
+#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
+
+#include "cutlass/gemm/device/default_gemm_configuration.h"
+#include "cutlass/gemm/kernel/default_gemm_universal.h"
+
+#include "cutlass/trace.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace device {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T_OUT, typename T_IN>
+__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk,
+ int64_t* splitk_buffer_offsets) {
+ // in_tensor: [problem_idx, k_partition, hidden_size]
+ // Note that different requests of in_tensor might have different hidden_size (=m*n)
+ // so, we need to use splitk_buffer_offsets.
+ // out_tensor: problem_idx * [hidden_size]
+
+ int const problem_idx = blockIdx.y;
+ GemmCoord problem = problem_sizes[problem_idx];
+ int const hidden_size = problem.m() * problem.n();
+ const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
+ T_OUT* out_tensor_ = out_tensor[problem_idx];
+
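+ // Grid-stride loop: accumulate the split-K partial results for each output element in fp32, then store as T_OUT.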
+ for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x) {
+ float sum = 0.0f;
+ for (int k_idx = 0; k_idx < splitk; k_idx++) {
+ sum += static_cast<float>(in_tensor_[k_idx * hidden_size + i]);
+ }
+ out_tensor_[i] = (T_OUT)(sum);
+ }
+}
+
+/// GEMM Grouped
+template <typename BaseKernel_>
+class BaseSplitkGrouped {
+ public:
+ using BaseKernel = BaseKernel_;
+
+ using ElementA = typename BaseKernel::ElementA;
+ using LayoutA = typename BaseKernel::LayoutA;
+ using TensorRefA = TensorRef<ElementA const, LayoutA>;
+ static ComplexTransform const kTransformA = BaseKernel::kTransformA;
+ static int const kAlignmentA = BaseKernel::kAlignmentA;
+
+ using ElementB = typename BaseKernel::ElementB;
+ using LayoutB = typename BaseKernel::LayoutB;
+ using TensorRefB = TensorRef<ElementB const, LayoutB>;
+ static ComplexTransform const kTransformB = BaseKernel::kTransformB;
+ static int const kAlignmentB = BaseKernel::kAlignmentB;
+
+ using ElementC = typename BaseKernel::ElementC;
+ using LayoutC = typename BaseKernel::LayoutC;
+ using TensorRefC = TensorRef<ElementC const, LayoutC>;
+ using TensorRefD = TensorRef<ElementC, LayoutC>;
+ static int const kAlignmentC = BaseKernel::kAlignmentC;
+
+ using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;
+
+ using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
+ using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle;
+
+ using Operator = typename BaseKernel::Operator;
+ using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;
+
+ using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
+ using MathOperator = typename WarpMmaOperator::MathOperator;
+ using OperatorClass = typename WarpMmaOperator::OperatorClass;
+ using ArchTag = typename WarpMmaOperator::ArchTag;
+ using ThreadblockShape = typename BaseKernel::Mma::Shape;
+ using WarpShape = typename BaseKernel::WarpShape;
+ using InstructionShape = typename BaseKernel::InstructionShape;
+ static int const kStages = BaseKernel::Mma::kStages;
+
+ /// Argument structure
+ using Arguments = typename BaseKernel::Arguments;
+
+ using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;
+
+ protected:
+ /// Kernel parameters object
+ typename BaseKernel::Params gemm_params_;
+
+ private:
+ /// Get the number of tiles across all problems in a group
+ static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) {
+ int32_t tiles = 0;
+ for (int32_t i = 0; i < problem_count; ++i) {
+ cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
+ BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
+ tiles += problem_tile_count(problem);
+ }
+ return tiles;
+ }
+
+ /// Copy from `data` to `workspace`
+ Status copy_to_workspace(void* workspace, void* data, size_t bytes) {
+ cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
+ if (cuda_error != cudaSuccess) {
+ // Call cudaGetLastError() to clear the error bit
+ cuda_error = cudaGetLastError();
+ CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error));
+ return Status::kErrorInternal;
+ }
+
+ return Status::kSuccess;
+ }
+
+ /// Precomputes scheduling information for the grouped GEMM
+ Status precompute(Arguments const& args, int32_t tile_count, void* workspace) {
+ size_t workspace_bytes = get_workspace_size(args);
+ std::vector<uint8_t> host_workspace(workspace_bytes);
+ BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes, args.problem_count, args.threadblock_count,
+ reinterpret_cast<void*>(host_workspace.data()));
+ return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
+ }
+
+ /// Reorder `data` according to `indices`
+ template <typename T>
+ static void reorder_array(T* data, std::vector<size_t> const& indices) {
+ // For now, simply create a copy of the data and then copy over to the original.
+ std::vector<T> copy(indices.size());
+ for (size_t i = 0; i < indices.size(); ++i) {
+ copy.at(i) = data[indices[i]];
+ }
+
+ memcpy(data, copy.data(), indices.size() * sizeof(T));
+ }
+
+ public:
+ /// Constructs the GEMM.
+ BaseSplitkGrouped() {}
+
+ /// Determines whether the GEMM can execute the given problem.
+ static Status can_implement(Arguments const& args) { return BaseKernel::can_implement(args); }
+
+ /// Get the number of tiles in a problem
+ static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) {
+ auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
+ return BaseKernel::ProblemVisitor::tile_count(grid);
+ }
+
+ /// Get the number of tiles across all problems in a group
+ static int32_t group_tile_count(Arguments const& args) {
+ if (args.host_problem_sizes == nullptr) {
+ CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes");
+ return -1;
+ }
+
+ return group_tile_count(args.host_problem_sizes, args.problem_count);
+ }
+
+ /// Gets the workspace size
+ static size_t get_workspace_size(Arguments const& args) {
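+ // The split-K workspace holds one ElementAccumulator per output element (m*n) per K-slice, summed over
+ // all problems, plus any precomputed schedule the problem visitor requires.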
+ size_t total_mn = 0;
+ for (int i = 0; i < args.problem_count; i++) {
+ total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n();
+ }
+ size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices;
+
+ if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
+ workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size(args.host_problem_sizes, args.problem_count,
+ args.threadblock_count);
+ }
+ return workSpaceSize;
+ }
+
+ /// Computes the grid shape
+ static dim3 get_grid_shape(Arguments const& args) { return dim3(args.threadblock_count, 1, 1); }
+
+ /// Computes the maximum number of active blocks per multiprocessor
+ static int maximum_active_blocks(int smem_capacity = -1) {
+ CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()");
+
+ int smem_size = static_cast<int>(sizeof(typename BaseKernel::SharedStorage));
+
+ CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
+
+ cudaError_t result;
+ if (smem_size > (48 << 10)) {
+ result = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+
+ if (result != cudaSuccess) {
+ // Call cudaGetLastError() to clear the error bit
+ result = cudaGetLastError();
+ CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
+ return -1;
+ }
+ }
+
+ int max_active_blocks = -1;
+ result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<BaseKernel>,
+ BaseKernel::kThreadCount, smem_size);
+
+ if (result != cudaSuccess) {
+ // Call cudaGetLastError() to clear the error bit
+ result = cudaGetLastError();
+ CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
+ << cudaGetErrorString(result));
+ return -1;
+ }
+
+ CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
+ return max_active_blocks;
+ }
+
+ /// Sorts each pointer passed in according to the indices that sort
+ /// `problem_sizes_ptr` in descending order of problem-K dimension.
+ static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr,
+ int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr,
+ int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr) {
+ std::vector<size_t> indices(problem_count);
+ std::iota(indices.begin(), indices.end(), 0);
+ std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_ptr](size_t i, size_t j) {
+ return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k();
+ });
+
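+ // Apply the same permutation to every per-problem array so leading dimensions and offsets
+ // stay consistent with the sorted problem sizes.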
+ reorder_array(problem_sizes_ptr, indices);
+ reorder_array(lda_host_ptr, indices);
+ reorder_array(ldb_host_ptr, indices);
+ reorder_array(ldc_host_ptr, indices);
+ reorder_array(ldd_host_ptr, indices);
+ reorder_array(offset_A_ptr, indices);
+ reorder_array(offset_B_ptr, indices);
+ reorder_array(offset_C_ptr, indices);
+ reorder_array(offset_D_ptr, indices);
+ }
+
+ /// Computes the number of threadblocks to launch for the grouped kernel
+ static int sufficient(cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0,
+ int available_sm_count = -1) {
+ // Determine the number of blocks that would be launched to fill up a single
+ // wave on the GPU with each SM having maximum occupancy.
+ int device_idx;
+ cudaError_t result = cudaGetDevice(&device_idx);
+ if (result != cudaSuccess) {
+ // Call cudaGetLastError() to clear the error bit
+ result = cudaGetLastError();
+ CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result));
+ return 0;
+ }
+
+ int multiprocessor_count;
+ result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx);
+ if (result != cudaSuccess) {
+ CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result));
+ return 0;
+ }
+
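+ // A negative or oversized caller-provided SM count means "use every SM on the device".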
+ bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count);
+ if (override_sm_count) {
+ available_sm_count = multiprocessor_count;
+ }
+
+ int max_active_blocks = maximum_active_blocks();
+ if (max_active_blocks <= 0) {
+ return 0;
+ }
+
+ int occupancy_based_block_count = available_sm_count * max_active_blocks;
+
+ if (problem_sizes_ptr == nullptr || problem_count == 0) {
+ return occupancy_based_block_count;
+ }
+
+ int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);
+
+ // If the group contains a single problem, launching the exact number of
+ // threadblocks needed to cover the problem minimizes the work performed
+ // per threadblock in finding the next tile to compute. We return total_tiles
+ // unless the user has provided the SM count.
+ if (problem_count == 1 && override_sm_count) {
+ return total_tiles;
+ }
+
+ // Choose between the full wave of threadblocks and the tile count. If there
+ // are fewer tiles in the group than threadblocks in the full wave, only
+ // some threadblocks will be assigned tiles. Those threadblocks
+ // which are not assigned tiles still need to perform the work of iterating through
+ // problem sizes to determine that they have no work to do. This competes for cycles
+ // with those threadblocks that are assigned tiles to compute.
+ return std::min(total_tiles, occupancy_based_block_count);
+ }
+
+ /// Initializes GEMM state from arguments.
+ Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
+ CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace "
+ << workspace << ", stream: " << (stream ? "non-null" : "null"));
+
+ // Workspace
+ size_t workspace_bytes = get_workspace_size(args);
+
+ if (workspace_bytes && !workspace) {
+ return Status::kErrorWorkspaceNull;
+ }
+
+ if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
+ int32_t tile_count = group_tile_count(args);
+ Status status = precompute(args, tile_count, workspace);
+ if (status != Status::kSuccess) {
+ return status;
+ }
+
+ gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count);
+ } else {
+ gemm_params_ = typename BaseKernel::Params(args, workspace);
+ }
+
+ // Specify shared memory capacity for kernel.
+ int smem_size = static_cast<int>(sizeof(typename BaseKernel::SharedStorage));
+
+ if (smem_size >= (48 << 10)) {
+ cudaError_t result =
+ cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+
+ if (result != cudaSuccess) {
+ return Status::kErrorInternal;
+ }
+ }
+
+ return Status::kSuccess;
+ }
+
+ /// Lightweight update given a subset of arguments
+ Status update(Arguments const& args, void* workspace = nullptr) {
+ size_t workspace_bytes = get_workspace_size(args);
+
+ if (workspace_bytes && !workspace) {
+ return Status::kErrorWorkspaceNull;
+ }
+
+ if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) {
+ int32_t tile_count = group_tile_count(args);
+ Status status = precompute(args, tile_count, workspace);
+ if (status != Status::kSuccess) {
+ return status;
+ }
+
+ gemm_params_.update(args, workspace, tile_count);
+ } else {
+ gemm_params_.update(args, workspace);
+ }
+
+ return Status::kSuccess;
+ }
+
+ /// Runs the kernel using initialized state.
+ Status run(cudaStream_t stream = nullptr) {
+ if (!gemm_params_.problem_visitor.problem_count) {
+ return Status::kSuccess;
+ }
+
+ //
+ // Launch kernel
+ //
+
+ // Launch splitk grouped gemm
+ {
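+ // Persistent threadblocks span the x dimension; split-K slices map onto the z grid dimension.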
+ dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices);
+ dim3 block(BaseKernel::kThreadCount, 1, 1);
+
+ int smem_size = static_cast<int>(sizeof(typename BaseKernel::SharedStorage));
+ cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(gemm_params_);
+
+ cudaError_t result = cudaGetLastError();
+
+ if (result != cudaSuccess) {
+ CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
+ return Status::kErrorInternal;
+ }
+ }
+
+ // Launch splitkReduction
+ {
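+ // blockIdx.y selects the problem; 32 blocks of 256 threads grid-stride over each problem's m*n outputs.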
+ dim3 grid(32, gemm_params_.problem_visitor.problem_count);
+ dim3 block(256);
+ splitkReduction<<<grid, block, 0, stream>>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split,
+ gemm_params_.problem_visitor.problem_sizes,
+ gemm_params_.split_k_slices, gemm_params_.splitk_buffer_offsets);
+
+ cudaError_t result = cudaGetLastError();
+
+ if (result != cudaSuccess) {
+ CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
+ return Status::kErrorInternal;
+ }
+ }
+
+ return Status::kSuccess;
+ }
+
+ /// Runs the kernel using initialized state.
+ Status operator()(cudaStream_t stream = nullptr) { return run(stream); }
+
+ /// Initializes and runs the kernel.
+ Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr) {
+ Status status = initialize(args, workspace, stream);
+
+ if (status == Status::kSuccess) {
+ status = run(stream);
+ }
+
+ return status;
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// GEMM Grouped
+template <typename GemmKernel_>
+class SplitkGemmGrouped : public BaseSplitkGrouped<GemmKernel_> {
+ public:
+ using GemmKernel = GemmKernel_;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace device
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h
similarity index 71%
rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h
rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h
index eb33a98e4246f..2b3478a38fc2e 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h
@@ -1,11 +1,12 @@
/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,51 +14,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-/*
- This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
- quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
- to be consumed by CUTLASS.
-
- Note that for int4, ThreadBlockK MUST be 64.
-
- */
-
#pragma once
-#include "cutlass/layout/matrix.h"
-#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
-#include "cutlass/platform/platform.h"
+#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
+#include "cutlass/layout/matrix.h"
+
+#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h"
+#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
namespace cutlass {
namespace gemm {
namespace kernel {
-template <typename TypeB, typename Arch, typename Enable = void>
-struct LayoutDetailsB {};
-
-// Volta specialiations. Volta will dequantize before STS, so we need a different operator
-template <typename TypeB>
-struct LayoutDetailsB<TypeB, arch::Sm70> {
- static constexpr int ThreadblockK = 64;
- using Layout = layout::RowMajor;
- static constexpr int ElementsPerAccess = 8;
- using Operator = cutlass::arch::OpMultiplyAdd;
-};
-
-// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
-// TODO - Switch this to column major for weights since gemms should be more performant.
-template <typename Arch>
-struct LayoutDetailsB<half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
- static constexpr int ThreadblockK = 64;
- using Layout = layout::RowMajor;
- static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
- using Operator = cutlass::arch::OpMultiplyAdd;
-};
-
template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
struct MixedGemmArchTraits {};
@@ -66,7 +38,7 @@ struct MixedGemmArchTraits {
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassSimt;
using AccType = float;
- using LayoutB = cutlass::layout::RowMajor;
+ using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int ElementsPerAccessA = 1;
static constexpr int ElementsPerAccessB = 1;
@@ -80,10 +52,13 @@ struct MixedGemmArchTraits {
// ========================= Volta Traits ===========================
// Volta will always dequantize after the global memory load.
// This will instantiate any HMMA tensorcore kernels for Volta.
+// Note that volta does not have native bfloat support so weights and activations will be casted to fp16
+// and compute will happen in fp16 then will be converted for bf16 output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<
TypeA, TypeB, cutlass::arch::Sm70,
- typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value>::type> {
+ typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value ||
+ cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
private:
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm70>;
@@ -103,10 +78,13 @@ struct MixedGemmArchTraits<
};
// ======================= Turing Traits ==============================
+// Note that turing does not have native bfloat support so weights and activations will be casted to fp16
+// and compute will happen in fp16 then will be converted for bf16 output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<
TypeA, TypeB, cutlass::arch::Sm75,
- typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value>::type> {
+ typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value ||
+ cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
private:
using LayoutDetails = LayoutDetailsB