From 17919717b57c1246a9f2629b7d3e9d523d80dcd6 Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Fri, 29 Mar 2024 10:24:19 -0700
Subject: [PATCH] add QMoE (#20108)

### Description
1. Introduce the latest cutlass extensions from TRTLLM, which give us the opportunity to upgrade cutlass (to 3.4) on the MoE side.
2. Fix a Windows build issue.
3. Add the Int4 MoE op (QMoE) and unit tests.

### Motivation and Context

---
cmake/onnxruntime_rocm_hipify.cmake | 2 +
docs/ContribOperators.md | 64 +
docs/OperatorKernels.md | 1 +
.../cuda/collective/sharded_moe.cc | 126 +-
.../contrib_ops/cuda/cuda_contrib_kernels.cc | 2 +
.../cuda/moe/cutlass_extensions/arch/mma.h | 110 ++
.../compute_occupancy.h | 21 +-
.../epilogue/thread/fused_activations.h} | 76 +-
.../epilogue_per_row_per_col_scale.h | 306 ++++
.../threadblock/epilogue_tensor_op_int32.h | 247 +++
.../epilogue_helpers.h | 129 +-
.../gemm/device/gemm_universal_base_compat.h | 384 +++++
.../gemm/device/splitk_gemm_grouped.h | 476 ++++++
.../gemm/kernel/default_fpA_intB_traits.h} | 61 +-
.../gemm/kernel/default_int8_traits.h | 51 +
.../gemm/kernel/default_splitk_gemm_grouped.h | 206 +++
.../gemm/kernel/fpA_intB_gemm.h | 513 ++++++
.../gemm/kernel/gemm_moe_problem_visitor.h | 66 +
.../gemm/kernel/gemm_with_epilogue_visitor.h | 516 ++++++
.../gemm/kernel/mixed_gemm_B_layout.h | 126 ++
.../gemm/kernel/moe_cutlass_kernel.h | 471 +++++
.../gemm/kernel}/moe_problem_visitor.h | 54 +-
.../gemm/kernel/splitk_gemm_grouped.h | 464 +++++
.../gemm/threadblock/default_dq_mma.h | 120 ++
.../threadblock/default_dq_mma_multistage.h | 289 ++++
.../threadblock/default_dq_mma_pipelined.h | 245 +++
.../gemm/threadblock/default_mma.h | 283 +++
.../gemm/threadblock/default_mma_bf16.h | 345 ++++
.../gemm/threadblock/dq_mma_base.h | 237 +++
.../gemm/threadblock/dq_mma_multistage.h | 107 ++
.../dq_mma_multistage_finegrained.h | 634 +++++++
.../threadblock/dq_mma_multistage_percol.h | 586 +++++++
.../gemm/threadblock/dq_mma_pipelined.h | 379 +++++
.../gemm/warp/default_mma_tensor_op.h | 103 ++
.../warp/mma_tensorop_compute_B_with_f16.h | 283 +++
.../gemm/warp/mma_tensorop_dequantizer.h | 534 ++++++
.../moe/cutlass_extensions/gemm_configs.h | 125 ++
.../interleaved_numeric_conversion.h | 392 +++++
.../tile_interleaved_layout.h | 2 +-
.../fine_grained_scale_zero_iterator.h | 222 +++
.../cutlass_extensions/weight_only_quant_op.h | 50 +
.../cuda/moe/ft_moe/cutlass_heuristic.cc | 4 +-
.../cuda/moe/ft_moe/cutlass_heuristic.h | 2 +-
.../cuda/moe/ft_moe/ft_gemm_configs.h | 58 -
.../cuda/moe/ft_moe/moe_cutlass_kernel.h | 463 -----
.../cuda/moe/ft_moe/moe_gemm_kernels.h | 10 +-
.../moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu | 11 +-
.../moe/ft_moe/moe_gemm_kernels_fp16_uint4.cu | 30 +
.../moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu | 12 +-
.../moe/ft_moe/moe_gemm_kernels_template.h | 139 +-
.../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 168 +-
onnxruntime/contrib_ops/cuda/moe/moe.cc | 92 +-
onnxruntime/contrib_ops/cuda/moe/moe_base.h | 94 +-
.../cuda/quantization/moe_quantization.cc | 143 ++
.../cuda/quantization/moe_quantization.h | 25 +
.../core/graph/contrib_ops/contrib_defs.cc | 58 +
onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 +
onnxruntime/test/contrib_ops/moe_test.cc | 1516 ++++++++++++-----
.../transformers/test_parity_mixtral_moe.py | 6 +-
.../python/transformers/test_parity_moe.py | 4 +-
60 files changed, 10748 insertions(+), 1497 deletions(-)
create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h
rename onnxruntime/contrib_ops/cuda/moe/{ft_moe =>
cutlass_extensions}/compute_occupancy.h (62%) rename onnxruntime/contrib_ops/cuda/moe/{ft_moe/gemm_moe_problem_visitor.h => cutlass_extensions/epilogue/thread/fused_activations.h} (57%) create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h rename onnxruntime/contrib_ops/cuda/moe/{ft_moe => cutlass_extensions}/epilogue_helpers.h (57%) create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h rename onnxruntime/contrib_ops/cuda/moe/{ft_moe/layout_traits_helper.h => cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h} (71%) create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_int8_traits.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h rename onnxruntime/contrib_ops/cuda/moe/{ft_moe => cutlass_extensions/gemm/kernel}/moe_problem_visitor.h (79%) create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma_bf16.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h rename onnxruntime/contrib_ops/cuda/moe/{ft_moe => 
cutlass_extensions}/tile_interleaved_layout.h (98%) create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h delete mode 100644 onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h delete mode 100644 onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h create mode 100644 onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint4.cu create mode 100644 onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc create mode 100644 onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index cadb06bb38707..0051f241e4f9b 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -60,6 +60,8 @@ set(contrib_ops_excluded_files "quantization/matmul_nbits.cc" "quantization/matmul_nbits.cuh" "quantization/matmul_nbits.cu" + "quantization/moe_quantization.h" + "quantization/moe_quantization.cc" "quantization/quantize_dequantize_linear.cc" "quantization/qordered_ops/qordered_attention_impl.cu" "quantization/qordered_ops/qordered_attention_impl.h" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 32a4ca16b7824..9b45cc02708d6 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -78,6 +78,7 @@ Do not modify directly.* * com.microsoft.QLinearSigmoid * com.microsoft.QLinearSoftmax * com.microsoft.QLinearWhere + * com.microsoft.QMoE * com.microsoft.QOrderedAttention * com.microsoft.QOrderedGelu * com.microsoft.QOrderedLayerNormalization @@ -4261,6 +4262,69 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.QMoE** + + Int4 MoE + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation_type : string
+
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
k : int
+
Number of top experts to select from the expert pool
+
normalize_routing_weights : int
+
Whether to normalize routing weights
+
+ +#### Inputs (7 - 11) + +
+
input : T
+
2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+
router_probs : T
+
2D input tensor with shape (num_rows, num_experts)
+
fc1_experts_weights : T1
+
3D input tensor with shape (num_experts, hidden_size, inter_size / 2), with two int4 weight values packed per uint8 element (see the packing sketch after this list)
+
fc1_scales : T
+
2D input tensor with shape (num_experts, inter_size)
+
fc1_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, inter_size)
+
fc2_experts_weights : T1
+
3D input tensor with shape (num_experts, inter_size, hidden_size / 2)
+
fc2_scales : T
+
2D input tensor with shape (num_experts, hidden_size)
+
fc2_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, hidden_size)
+
fc3_experts_weights (optional) : T1
+
3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)
+
fc3_scales (optional) : T
+
2D optional input tensor with shape (num_experts, inter_size)
+
fc3_experts_bias (optional) : T
+
2D optional input tensor with shape (num_experts, inter_size)
+
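The `/ 2` in the fc1/fc2/fc3 weight shapes above reflects int4 packing: each uint8 element stores two 4-bit weights, which are dequantized with the matching per-column scales (`fc1_scales`, `fc2_scales`, `fc3_scales`). A minimal C++ sketch of that convention follows; the low-nibble-first order and the plain row-major layout are illustrative assumptions, and the CUTLASS kernels additionally interleave the packed weights for tensor-core access:

    // Illustrative only: plain packing of two signed int4 values (range [-8, 7]) per uint8.
    #include <cstdint>
    #include <vector>

    std::vector<uint8_t> pack_int4(const std::vector<int8_t>& w) {  // w.size() must be even
      std::vector<uint8_t> packed(w.size() / 2);
      for (size_t i = 0; i < packed.size(); ++i) {
        // Low nibble holds w[2i], high nibble holds w[2i + 1] (assumed order).
        packed[i] = static_cast<uint8_t>((w[2 * i] & 0x0F) | ((w[2 * i + 1] & 0x0F) << 4));
      }
      return packed;
    }

    // Dequantize one packed byte with the per-column scales of its two columns.
    inline void unpack_dequant(uint8_t b, float scale_lo, float scale_hi, float out[2]) {
      auto to_int4 = [](uint8_t nibble) { return static_cast<int8_t>(nibble << 4) >> 4; };  // sign-extend
      out[0] = to_int4(b & 0x0F) * scale_lo;
      out[1] = to_int4(b >> 4) * scale_hi;
    }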
+ +#### Outputs + +
+
output : T
+
2D output tensor with shape (num_rows, hidden_size) or 3D output tensor with shape (batch_size, sequence_length, hidden_size); the per-token computation is sketched below
+
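A sketch of the per-token computation, assuming QMoE keeps the semantics of the float com.microsoft.MoE op (router softmax followed by top-k selection, with fc3 acting as a Mixtral-style multiplicative gate when present):

    \[
    \mathrm{output}(x) \;=\; \sum_{e \,\in\, \mathrm{TopK}_k(\mathrm{softmax}(p(x)))} r_e \, \mathrm{FFN}_e(x),
    \qquad
    \mathrm{FFN}_e(x) \;=\;
    \begin{cases}
    \big(\sigma(x W_1^{(e)} + b_1^{(e)}) \odot (x W_3^{(e)} + b_3^{(e)})\big)\, W_2^{(e)} + b_2^{(e)} & \text{if fc3 is present} \\
    \sigma(x W_1^{(e)} + b_1^{(e)})\, W_2^{(e)} + b_2^{(e)} & \text{otherwise}
    \end{cases}
    \]

Here \(p(x)\) is the token's row of router_probs, \(r_e\) the selected routing weight (renormalized to sum to 1 when normalize_routing_weights is 1), \(\sigma\) the activation_type, and each \(W^{(e)}\) is dequantized on the fly from the packed int4 weights using the per-column scales.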
+ +#### Type Constraints + +
+
T : tensor(float16)
+
Constrain input and output types to float16 tensors.
+
T1 : tensor(uint8)
+
Constrain weights type to uint8 tensors.
+
+ + ### **com.microsoft.QOrderedAttention** Quantized version of simplified Multi-Head Self Attention(using int8 with specific matrix Layout). diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index bca8e17b3dfd4..c963781435465 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -868,6 +868,7 @@ Do not modify directly.* |PackedAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| +|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float16)
**T1** = tensor(uint8)| |QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* relative_position_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 2efc37cf98010..1dbbe8c4e7eaa 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -18,23 +18,15 @@ namespace cuda { #if defined(ORT_USE_NCCL) -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ShardedMoE, \ - kMSDomain, \ - 1, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(0, 0) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ShardedMoE, kMSDomain, 1, T, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ ShardedMoE); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) -using namespace ONNX_NAMESPACE; - template ShardedMoE::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("tensor_shards", &tensor_shards_).IsOK()); @@ -69,25 +61,23 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { const Tensor* fc3_experts_bias_optional = context->Input(7); MoEParameters moe_params(tensor_shards_); - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional, - fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional)); + MoEQuantType quant_type = MoEQuantType::None; + ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, + fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, + fc3_experts_weights_optional, fc3_experts_bias_optional)); - ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, - "num_experts should be divisible by world_size"); + ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); if (moe_params.parallel_type == MoEParallelType::EP || moe_params.parallel_type == MoEParallelType::EPAndTP) { ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event)); } - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, - fc3_experts_weights_optional != nullptr, + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, normalize_routing_weights_); - size_t ws_size = - moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), - static_cast(moe_params.inter_size), - static_cast(moe_params.num_experts), static_cast(k_)); + size_t ws_size = moe_runner.getWorkspaceSize( + static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), static_cast(k_)); size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT); size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT); @@ -107,30 +97,29 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { const CudaT* fc_scales_ptr = nullptr; - moe_runner.run_moe_fc(reinterpret_cast(input->template Data()), - reinterpret_cast(router_probs->template Data()), - reinterpret_cast(fc1_experts_weights->template Data()), - std::move(fc_scales_ptr), - fc1_experts_bias_optional == nullptr - ? 
nullptr - : reinterpret_cast(fc1_experts_bias_optional->template Data()), - activation_type_, - fc3_experts_weights_optional == nullptr - ? nullptr - : reinterpret_cast(fc3_experts_weights_optional->template Data()), - std::move(fc_scales_ptr), - fc3_experts_bias_optional == nullptr - ? nullptr - : reinterpret_cast(fc3_experts_bias_optional->template Data()), - reinterpret_cast(fc2_experts_weights->template Data()), - std::move(fc_scales_ptr), static_cast(moe_params.num_rows), - static_cast(moe_params.hidden_size), - static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), - static_cast(moe_params.local_num_experts), static_cast(local_experts_start_index_), - static_cast(k_), reinterpret_cast(work_space.get()), - reinterpret_cast(fc2_output.get()), reinterpret_cast(expert_scales.get()), - reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), - reinterpret_cast(expert_for_source_row.get()), Stream(context)); + moe_runner.run_moe_fc( + reinterpret_cast(input->template Data()), + reinterpret_cast(router_probs->template Data()), + reinterpret_cast(fc1_experts_weights->template Data()), std::move(fc_scales_ptr), + fc1_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc1_experts_bias_optional->template Data()), + activation_type_, + fc3_experts_weights_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_weights_optional->template Data()), + std::move(fc_scales_ptr), + fc3_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_bias_optional->template Data()), + reinterpret_cast(fc2_experts_weights->template Data()), std::move(fc_scales_ptr), + static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), + static_cast(moe_params.local_num_experts), static_cast(local_experts_start_index_), + static_cast(k_), reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), Stream(context)); Tensor* output = context->Output(0, input->Shape()); @@ -146,12 +135,8 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { ORT_ENFORCE(moe_params.tensor_shards == nccl_->Size()); NCCL_RETURN_IF_ERROR(ncclGroupStart()); NCCL_RETURN_IF_ERROR(ncclAllReduce(reinterpret_cast(fc2_output.get()), - reinterpret_cast(fc2_output_bc.get()), - fc2_output_size / sizeof(CudaT), - GetNcclDataType(input->DataType()), - ncclSum, - nccl_->Comm(), - Stream(context))); + reinterpret_cast(fc2_output_bc.get()), fc2_output_size / sizeof(CudaT), + GetNcclDataType(input->DataType()), ncclSum, nccl_->Comm(), Stream(context))); NCCL_RETURN_IF_ERROR(ncclGroupEnd()); } @@ -166,19 +151,12 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { NCCL_RETURN_IF_ERROR(ncclGroupStart()); for (int rank = 0; rank < nccl_->Size(); ++rank) { int64_t experts_start_index = rank_to_experts_start_index_[rank]; - moe_runner.get_total_rows_info(experts_start_index, - moe_params.local_num_experts, - total_past_rows, + moe_runner.get_total_rows_info(experts_start_index, moe_params.local_num_experts, total_past_rows, total_covered_rows); const char* src = reinterpret_cast(fc2_output.get()) + total_past_rows * stride_bytes; char* dst = reinterpret_cast(fc2_output_bc.get()) + total_past_rows * stride_bytes; - NCCL_RETURN_IF_ERROR(ncclBroadcast(src, - dst, - total_covered_rows * stride_count, - 
GetNcclDataType(input->DataType()), - rank, - nccl_->Comm(), - Stream(context))); + NCCL_RETURN_IF_ERROR(ncclBroadcast(src, dst, total_covered_rows * stride_count, + GetNcclDataType(input->DataType()), rank, nccl_->Comm(), Stream(context))); } NCCL_RETURN_IF_ERROR(ncclGroupEnd()); } @@ -197,8 +175,7 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { } template -Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, - OpKernelContext* context, +Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, OpKernelContext* context, cudaEvent_t& cuda_event) const { if (rank_to_experts_start_index_[0] != std::numeric_limits::min()) { return Status::OK(); @@ -215,23 +192,16 @@ Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, IAllocator::MakeUniquePtr(allocator, nccl_->Size(), false, stream); // Only happens in the first run. - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(), - &local_experts_start_index_, - IndexTypeSize, - cudaMemcpyHostToDevice, - Stream(context))); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(), &local_experts_start_index_, IndexTypeSize, + cudaMemcpyHostToDevice, Stream(context))); NCCL_RETURN_IF_ERROR(ncclAllGather(reinterpret_cast(experts_start_index_d.get()), - reinterpret_cast(rank_to_experts_start_index_d.get()), - 1, - GetNcclDataType(DataTypeImpl::GetType()), - nccl_->Comm(), + reinterpret_cast(rank_to_experts_start_index_d.get()), 1, + GetNcclDataType(DataTypeImpl::GetType()), nccl_->Comm(), Stream(context))); // The const_cast<> violates the const modifier to make sure the synchronization happens only once per session. CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast(rank_to_experts_start_index_.data()), - rank_to_experts_start_index_d.get(), - nccl_->Size() * IndexTypeSize, - cudaMemcpyDeviceToHost, - Stream(context))); + rank_to_experts_start_index_d.get(), nccl_->Size() * IndexTypeSize, + cudaMemcpyDeviceToHost, Stream(context))); CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(&cuda_event, cudaEventDisableTiming)); CUDA_RETURN_IF_ERROR(cudaEventRecord(cuda_event, Stream(context))); diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 57e951d3a68ff..3621ffc5c64ca 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -72,6 +72,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QMoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); @@ -275,6 +276,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git 
a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h
new file mode 100644
index 0000000000000..07c38c58e446a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h
@@ -0,0 +1,110 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Templates exposing architecture support for multiply-add operations
+*/
+
+#pragma once
+#include "contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace arch {
+
+// Tag which triggers an MMA that dequantizes interleaved B to the type of A before the multiply-add
+struct OpMultiplyAddDequantizeInterleavedBToA;
+
+/*
+  Below we have extra tags to signal what kind of dequantization we want to do
+  (per col, scale only fine grained, finegrained with zero). This still lets us use
+  the existing template infrastructure (incl. that in CUTLASS). However, we
+  split out the template below into OpMultiplyAddDequantizeInterleavedBToA along
+  with the quantization op before instantiating the GEMM pieces.
+
+  Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
+  code we need to duplicate.
+ */
+struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
+struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
+struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
+
+// The default just forwards the original operator
+template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
+struct TagOperator {
+  using TaggedOperator = MmaOp;
+};
+
+// Specializations below attach more information to the operator
+template <>
+struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY> {
+  using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
+};
+
+template <>
+struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY> {
+  using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
+};
+
+template <>
+struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS> {
+  using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
+};
+
+// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original
+// operator + the extra information. If no extra info was tagged, the dequant op defaults to
+// per-column scaling.
+template <typename TaggedMmaOp>
+struct DetagOperator {
+  using Operator = TaggedMmaOp;
+  static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
+};
+
+template <>
+struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale> {
+  using Operator = OpMultiplyAddDequantizeInterleavedBToA;
+  static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
+};
+
+template <>
+struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale> {
+  using Operator = OpMultiplyAddDequantizeInterleavedBToA;
+  static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
+};
+
+template <>
+struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias> {
+  using Operator = OpMultiplyAddDequantizeInterleavedBToA;
+  static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
+};
+
+} // namespace arch
+} // namespace cutlass
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h
similarity index 62%
rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h
index 86136ea244e23..99cbe4a66049e 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h
@@ -26,19 +26,22 @@ namespace ort_fastertransformer {
 template <typename GemmKernel>
 inline int compute_occupancy_for_kernel() {
-  int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
+  int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));
   if (smem_size > (48 << 10)) {
-    cudaError_t status =
-        cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
-    if (status == cudaError::cudaErrorInvalidValue) {
-      // Clear the error bit since we can ignore this.
-      // This should mean that smem_size > cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
-      // occupancy of 0. This will cause the heuristic to ignore this configuration.
-      status = cudaGetLastError();
+    cudaFuncAttributes attr;
+    int device = 0;
+    int max_smem_per_block = 0;
+    CUDA_CALL_THROW(cudaGetDevice(&device));
+    CUDA_CALL_THROW(cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
+    CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
+    if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block)) {
+      // This should mean that
+      // cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
+      // wouldn't work. In that case, we return an occupancy of 0.
This will cause the heuristic to ignore this + // configuration. return 0; } - CUDA_CALL_THROW(status); } int max_active_blocks = -1; diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h similarity index 57% rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h index 311ed323cb90c..da8cb6d294efd 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/gemm_moe_problem_visitor.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h @@ -28,52 +28,68 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ - /*! \file - \brief Scheduler for grouped GEMM + \brief Functor performing linear combination with a maximum operation used by epilogues. */ #pragma once +#include "cutlass/array.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -#include "cutlass/matrix_coord.h" - -#include "moe_problem_visitor.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { -namespace gemm { -namespace kernel { +namespace epilogue { +namespace thread { -/// Visitor class to abstract away the algorithm for iterating over tiles -template -struct GemmMoeProblemVisitor - : public MoeProblemVisitor, ThreadblockShape, - GroupScheduleMode_, PrefetchTileCount, ThreadCount> { - static bool const kTransposed = Transposed; +///////////////////////////////////////////////////////////////////////////////////////////////// - using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; - using Base = - MoeProblemVisitor; - using Params = typename Base::Params; - using SharedStorage = typename Base::SharedStorage; +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} - // - // Methods - // - CUTLASS_DEVICE - GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) - : Base(params_, shared_storage_, block_idx) {} -}; +__forceinline__ __device__ float tanh_opt(float x) { +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + float const exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} ///////////////////////////////////////////////////////////////////////////////////////////////// +template <> +struct GELU_taylor { + static bool const kIsHeavy = true; + + CUTLASS_DEVICE + float operator()(float const& z) const { + float k0 = static_cast(0.7978845608028654); + float k1 = static_cast(0.044715); + + return static_cast( + cutlass::constants::half() * z * + (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_DEVICE + float operator()(float const& 
scalar, Params const& params_) const { return this->operator()(scalar); } +}; -} // namespace kernel -} // namespace gemm +} // namespace thread +} // namespace epilogue } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h new file mode 100644 index 0000000000000..affd1d83a35de --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h @@ -0,0 +1,306 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column. 
+ + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_conversion.h" +#include "tensorrt_llm/common/quantization.h" + +namespace tk = tensorrt_llm::common; + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +class EpilogueVisitorPerRowPerCol { + public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using AlphaScaleElementType = typename ScaleTileIterator::Element; + + using ElementCompute = ElementCompute_; + using AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + + /// Argument structure + struct Arguments { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} + + explicit Arguments(typename ElementwiseFunctor::Params elementwise_) + : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} + + Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, int64_t batch_stride_C_, + int64_t batch_stride_D_) + : elementwise(elementwise_), + batch_stride_alpha(batch_stride_alpha_), + batch_stride_C(batch_stride_C_), + batch_stride_D(batch_stride_D_) {} + }; + + struct Params { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + explicit Params(Arguments const& args) + : elementwise(args.elementwise), + batch_stride_alpha(args.batch_stride_alpha), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D) {} + }; + + /// Shared storage + struct SharedStorage {}; + + private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + bool const per_token_quant_; + bool const per_channel_quant_; + + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + + AlphaScaleElementType element_alpha_row_ = 1.0f; + AlphaScaleElementType element_alpha_col_ = 1.0f; + typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator beta_; + + int 
column_offset_; + + MatrixCoord thread_offset_; + + public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, + typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, + AlphaScaleElementType* ptr_alpha_row, AlphaScaleElementType* ptr_alpha_col, + typename OutputTileIterator::Element* ptr_C, typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), + int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) + : params_(params), + shared_storage_(shared_storage), + extent_(problem_size), + elementwise_(params.elementwise), + per_token_quant_(quant_option.hasPerTokenScaling()), + per_channel_quant_(quant_option.hasPerChannelScaling()), + ptr_alpha_row_(ptr_alpha_row), + ptr_alpha_col_(ptr_alpha_col), + iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset), + iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset), + iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset), + extent_real_(problem_size_real) { + beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); + } + + if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) { + element_alpha_col_ = *ptr_alpha_col_; + } + + if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) { + element_alpha_row_ = *ptr_alpha_row_; + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() { + if (per_channel_quant_) { + iterator_alpha_col_.load(fragment_alpha_col_); + } + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + fragment_C_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) { + int thread_offset_row = + iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); + + arch::global_load( + element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); + } + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) { + NumericArrayConverter source_converter; + + ComputeFragment 
result = source_converter(accum); + if (per_channel_quant_) { + ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; + result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); + } else { + result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) {} + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + + private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, ComputeFragment const& scale_col, + AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col[i] * scale_row); + } + + return result; + } + + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, AlphaScaleElementType const& scale_col, + AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col * scale_row); + } + + return result; + } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h new file mode 100644 index 0000000000000..40f126d56616a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -0,0 +1,247 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h + +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/platform/platform.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/epilogue/thread/linear_combination_gelu.h" +#include "cutlass/epilogue/thread/linear_combination_hardswish.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_relu0.h" +#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" + +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. 
+template +struct DefaultIteratorsTensorOp { + using WarpTileIterator = + cutlass::epilogue::warp::TileIteratorTensorOpMixed; + + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; + + static int const kFragmentsPerIteration = 2; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. +/// +/// Satisfies: ReadableTileIterator +/// +template +class SharedLoadIteratorMixed { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = int32_t; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = + Array; + + /// Memory access size + using AccessType = AlignedArray; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + + private: + // + // Data members + // + + /// Byte-level pointer + LoadType const* pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed(TensorRef ref, int thread_idx) : stride_((ref.stride(0) / LoadType::kElements)) { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] = reinterpret_cast(ref.data()); + + int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = (col_idx * static_cast(sizeof(LoadType)) / 128) % kLoadsPerAccess; + + col_idx += (bank_offset + i) % kLoadsPerAccess; + + pointers_[i] += thread_offset.row() * stride_ + col_idx; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const { + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + group * ThreadMap::Delta::kGroup * stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ 
+ pointer_offset / LoadType::kElements; + + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + int vector_idx = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess); + + LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; + } + } + } + } + } + } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_pointer_offset(frag, 0); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h similarity index 57% rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h index b18a70e899d1c..b784646c31f84 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h @@ -1,11 +1,12 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
- *    http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
@@ -24,139 +25,85 @@
 #pragma once
 
-#include "cutlass/array.h"
-#include "cutlass/cutlass.h"
-#include "cutlass/epilogue/thread/activation.h"
-#include "cutlass/epilogue/thread/scale_type.h"
-#include "cutlass/functional.h"
-#include "cutlass/half.h"
-#include "cutlass/numeric_conversion.h"
-#include "cutlass/numeric_types.h"
+#include "contrib_ops/cuda/moe/cutlass_extensions/epilogue/thread/fused_activations.h"
 #include "cutlass/epilogue/thread/linear_combination.h"
 #include "cutlass/epilogue/thread/linear_combination_generic.h"
 #include "cutlass/epilogue/thread/linear_combination_relu.h"
 #include "cutlass/epilogue/thread/linear_combination_silu.h"
 
-namespace cutlass {
-namespace epilogue {
-namespace thread {
-
-__forceinline__ __device__ float copysignf_pos(float a, float b) {
-  float r;
-  r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
-  return r;
-}
-
-__forceinline__ __device__ float tanh_opt(float x) {
-#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
-  const float exp_val = -1.f * fabs(2 * x);
-  return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
-#else
-  return fast_tanh(x);
-#endif
-}
-
-template <>
-struct GELU_taylor<float> {
-  static const bool kIsHeavy = true;
-  CUTLASS_DEVICE
-  float operator()(float const& z) const {
-    float k0 = float(0.7978845608028654);
-    float k1 = float(0.044715);
-
-    return float(
-        cutlass::constants::half<float>() * z *
-        (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
-  }
-
-  using Params = LinearCombinationGenericParams<float>;
-
-  CUTLASS_DEVICE
-  float operator()(float const& scalar, Params const& params_) const { return this->operator()(scalar); }
-};
-
-}  // namespace thread
-}  // namespace epilogue
-}  // namespace cutlass
-
 namespace ort_fastertransformer {
 
 struct EpilogueOpBiasSilu {};
 
-struct EpilogueOpNoBiasSilu {};
-
 struct EpilogueOpBiasReLU {};
 
-struct EpilogueOpNoBiasReLU {};
-
 struct EpilogueOpBiasFtGelu {};
 
-struct EpilogueOpNoBiasFtGelu {};
+struct EpilogueOpDefaultSilu {};
+
+struct EpilogueOpDefaultReLU {};
+
+struct EpilogueOpDefaultFtGelu {};
 
 struct EpilogueOpBias {};
 
-struct EpilogueOpNoBias {};
+struct EpilogueOpDefault {};
 
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
 struct Epilogue {};
 
+constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;
+
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
 struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu> {
   using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
-                                                              ElementAccumulator,
-                                                              cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
+                                                              ElementAccumulator, BiasScaleMode>;
 };
 
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBiasSilu> {
-  using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
-                                                              ElementAccumulator,
-                                                              cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU> {
+  using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
+                                                              ElementAccumulator, BiasScaleMode>;
 };
 
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU> {
-  using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
-                                                              ElementAccumulator,
-                                                              cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu> {
+  using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
+      cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator,
+      ElementAccumulator, BiasScaleMode, cutlass::FloatRoundStyle::round_to_nearest, true>;
 };
 
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBiasReLU> {
-  using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
-                                                              ElementAccumulator,
-                                                              cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias> {
+  using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
+                                                          ElementAccumulator, BiasScaleMode>;
 };
 
+constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default;
+
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu> {
-  using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
-      cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator,
-      ElementAccumulator, cutlass::epilogue::thread::ScaleType::NoBetaScaling,
-      cutlass::FloatRoundStyle::round_to_nearest, true>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultSilu> {
+  using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
+                                                              ElementAccumulator, DefaultScaleMode>;
 };
 
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBiasFtGelu> {
-  using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
-      cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator,
-      ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling,
-      cutlass::FloatRoundStyle::round_to_nearest, true>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultReLU> {
+  using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess, ElementAccumulator,
+                                                              ElementAccumulator, DefaultScaleMode>;
 };
 
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias> {
-  using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
-                                                          ElementAccumulator,
-                                                          cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultFtGelu> {
+  using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
+      cutlass::epilogue::thread::GELU_taylor, ElementType, ElementsPerVectorAccess, ElementAccumulator,
+      ElementAccumulator, DefaultScaleMode, cutlass::FloatRoundStyle::round_to_nearest, true>;
 };
 
 template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
-struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpNoBias> {
-  using Op =
-      cutlass::epilogue::thread::LinearCombination<
-          ElementType, ElementsPerVectorAccess, ElementAccumulator,
-          ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;
+struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefault> {
+  using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
+                                                          ElementAccumulator, DefaultScaleMode>;
 };
 
 }  // namespace ort_fastertransformer
diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
new file mode 100644
index 0000000000000..f5064afc23ae0
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
@@ -0,0 +1,384 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +// #include +#include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/* + This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) + It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs + and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs. + + Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support + that feature at the moment. 
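+
+   As a quick orientation, the intended call sequence follows the classic
+   CUTLASS 2.x device-layer lifecycle. A minimal sketch, where MyGemmKernel
+   stands in for a fully specialized kernel type and stream is an existing
+   cudaStream_t (illustrative only, not part of the upstream comment):
+
+     GemmUniversalBaseCompat<MyGemmKernel> gemm_op;
+     typename MyGemmKernel::Arguments args = ...;  // problem size, tensor refs, epilogue params
+     if (GemmUniversalBaseCompat<MyGemmKernel>::can_implement(args) == Status::kSuccess) {
+       void* workspace = nullptr;
+       cudaMalloc(&workspace, GemmUniversalBaseCompat<MyGemmKernel>::get_workspace_size(args));
+       gemm_op.initialize(args, workspace, stream);  // may clear the workspace
+       gemm_op.run(stream);                          // launches cutlass::Kernel<MyGemmKernel>
+     }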
+ */
+
+template <typename GemmKernel_>
+class GemmUniversalBaseCompat {
+ public:
+  using GemmKernel = GemmKernel_;
+  using ThreadblockShape = typename GemmKernel::Mma::Shape;
+
+  using ElementA = typename GemmKernel::ElementA;
+  using LayoutA = typename GemmKernel::LayoutA;
+  using TensorRefA = TensorRef<ElementA const, LayoutA>;
+  static ComplexTransform const kTransformA = GemmKernel::kTransformA;
+
+  using ElementB = typename GemmKernel::ElementB;
+  using LayoutB = typename GemmKernel::LayoutB;
+  using TensorRefB = TensorRef<ElementB const, LayoutB>;
+  static ComplexTransform const kTransformB = GemmKernel::kTransformB;
+
+  using ElementC = typename GemmKernel::ElementC;
+  using LayoutC = typename GemmKernel::LayoutC;
+  using TensorRefC = TensorRef<ElementC const, LayoutC>;
+  using TensorRefD = TensorRef<ElementC, LayoutC>;
+
+  using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;
+
+  using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
+  using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
+  using Operator = typename GemmKernel::Operator;
+
+  /// Argument structure
+  using Arguments = typename GemmKernel::Arguments;
+
+ protected:
+  /// Kernel parameters object
+  typename GemmKernel::Params params_;
+
+ protected:
+  /// Private helper to obtain the grid dimensions with fix-up for split-K
+  static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) {
+    // Determine grid shape
+    ThreadblockSwizzle threadblock_swizzle;
+
+    grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
+        args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);
+
+    gemm_k_size = args.problem_size.k();
+
+    if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
+      int const kAlignK =
+          const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);
+
+      gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);
+
+      if (gemm_k_size) {
+        grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
+      }
+    }
+  }
+
+ public:
+  /// Constructs the GEMM.
+  GemmUniversalBaseCompat() {}
+
+  /// Determines whether the GEMM can execute the given problem.
+  static Status can_implement(Arguments const& args) {
+    // Determine grid shape
+    cutlass::gemm::GemmCoord grid_tiled_shape;
+    int gemm_k_size = 0;
+
+    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
+
+    ThreadblockSwizzle threadblock_swizzle;
+    dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
+
+    uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);
+
+    if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
+      return Status::kErrorInvalidProblem;
+    }
+
+    return GemmKernel::can_implement(args);
+  }
+
+  /// Gets the workspace size
+  static size_t get_workspace_size(Arguments const& args) {
+    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");
+
+    size_t workspace_bytes = 0;
+
+    // Determine grid shape
+    cutlass::gemm::GemmCoord grid_tiled_shape;
+    int gemm_k_size = 0;
+
+    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
+
+    if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
+      // Split-K parallel always requires a temporary workspace
+      workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
+    } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
+      // Serial split-K only requires a temporary workspace if the number of partitions along the
+      // GEMM K dimension is greater than one.
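+      // Worked example (illustrative numbers, not from the source): a
+      // 4096x4096 output tiled by 128x128 threadblocks yields a 32x32 grid of
+      // tiles, so this lock storage is 32 * 32 * sizeof(int) = 4 KiB,
+      // regardless of how many K partitions the split uses.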
+      workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
+    }
+
+    CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);
+
+    workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);
+
+    return workspace_bytes;
+  }
+
+  /// Computes the grid shape
+  static dim3 get_grid_shape(Arguments const& args) {
+    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");
+
+    ThreadblockSwizzle threadblock_swizzle;
+
+    cutlass::gemm::GemmCoord grid_tiled_shape;
+    int gemm_k_size = 0;
+
+    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
+    dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);
+
+    CUTLASS_TRACE_HOST("  grid_tiled_shape: " << grid_tiled_shape << "\n"
+                       << "  result = {" << result << "}");
+
+    return result;
+  }
+
+  /// Computes the maximum number of active blocks per multiprocessor
+  static int maximum_active_blocks(int smem_capacity = -1) {
+    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");
+
+    int max_active_blocks = -1;
+    int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));
+
+    CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");
+
+    if (smem_size <= (48 << 10)) {
+      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
+                                                                         GemmKernel::kThreadCount, smem_size);
+
+      if (result == cudaSuccess) {
+        CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
+        return max_active_blocks;
+      }
+    } else {
+      // Query assuming zero shared memory then compute occupancy limit based on SMEM
+      cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel<GemmKernel>,
+                                                                         GemmKernel::kThreadCount, 0);
+
+      if (result != cudaSuccess) {
+        CUTLASS_TRACE_HOST("  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
+                           << cudaGetErrorString(result));
+
+        return -1;
+      }
+
+      if (smem_capacity < 0) {
+        int device_idx = 0;
+        result = cudaGetDevice(&device_idx);
+
+        if (result != cudaSuccess) {
+          return -1;
+        }
+
+        cudaDeviceProp properties;
+        result = cudaGetDeviceProperties(&properties, device_idx);
+
+        if (result != cudaSuccess) {
+          return -1;
+        }
+
+        smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
+      }
+
+      int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
+
+      CUTLASS_TRACE_HOST("  occupancy: " << occupancy);
+
+      return occupancy;
+    }
+
+    CUTLASS_TRACE_HOST("  returning internal error");
+
+    return -1;
+  }
+
+  /// Initializes GEMM state from arguments.
+  Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
+    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
+                       << workspace << ", stream: " << (stream ? "non-null" : "null"));
+
+    size_t workspace_bytes = get_workspace_size(args);
+
+    CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);
+
+    if (workspace_bytes) {
+      if (!workspace) {
+        CUTLASS_TRACE_HOST("  error: device workspace must not be null");
+
+        return Status::kErrorWorkspaceNull;
+      }
+
+      if (args.mode == GemmUniversalMode::kGemm) {
+        CUTLASS_TRACE_HOST("  clearing device workspace");
+        cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);
+
+        if (result != cudaSuccess) {
+          CUTLASS_TRACE_HOST("  cudaMemsetAsync() returned error " << cudaGetErrorString(result));
+
+          return Status::kErrorInternal;
+        }
+      }
+    }
+
+    // Get CUDA grid shape
+    cutlass::gemm::GemmCoord grid_tiled_shape;
+    int gemm_k_size = 0;
+
+    get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
+
+    // Initialize the Params structure
+    params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));
+
+    // Specify shared memory capacity for kernel.
+    int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));
+
+    if (smem_size >= (48 << 10)) {
+      cudaError_t result =
+          cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+
+      if (result != cudaSuccess) {
+        return Status::kErrorInternal;
+      }
+    }
+
+    return Status::kSuccess;
+  }
+
+  /// Lightweight update given a subset of arguments
+  Status update(Arguments const& args, void* workspace = nullptr) {
+    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);
+
+    size_t workspace_bytes = get_workspace_size(args);
+
+    if (workspace_bytes && !workspace) {
+      return Status::kErrorWorkspaceNull;
+    }
+
+    params_.update(args, workspace);
+
+    return Status::kSuccess;
+  }
+
+  /// Runs the kernel using initialized state.
+  Status run(cudaStream_t stream = nullptr) {
+    CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");
+
+    //
+    // Configure grid and block dimensions
+    //
+
+    ThreadblockSwizzle threadblock_swizzle;
+
+    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
+    dim3 block(GemmKernel::kThreadCount, 1, 1);
+
+    int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));
+
+    //
+    // Launch kernel
+    //
+
+    CUTLASS_TRACE_HOST("  grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");
+
+    // Launch
+    cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
+
+    //
+    // Query for errors
+    //
+    cudaError_t result = cudaGetLastError();
+
+    if (result != cudaSuccess) {
+      CUTLASS_TRACE_HOST("  grid launch failed with error " << cudaGetErrorString(result));
+      return Status::kErrorInternal;
+    }
+
+    return Status::kSuccess;
+  }
+
+  /// Runs the kernel using initialized state.
+  Status operator()(cudaStream_t stream = nullptr) { return run(stream); }
+
+  /// Runs the kernel using initialized state.
+ Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h new file mode 100644 index 0000000000000..b226b73e86fe1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/device/splitk_gemm_grouped.h @@ -0,0 +1,476 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
+ \file + \brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h +*/ + +#pragma once + +#include +#include +#include +#include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk, + int64_t* splitk_buffer_offsets) { + // in_tensor: [problem_idx, k_partition, hidden_size] + // Note that different requests of in_tensor might have different hidden_size (=m*n) + // so, we need to use splitk_buffer_offsets. + // out_tensor: problem_idx * [hidden_size] + + int const problem_idx = blockIdx.y; + GemmCoord problem = problem_sizes[problem_idx]; + int const hidden_size = problem.m() * problem.n(); + const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; + T_OUT* out_tensor_ = out_tensor[problem_idx]; + + for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x) { + float sum = 0.0f; + for (int k_idx = 0; k_idx < splitk; k_idx++) { + sum += static_cast(in_tensor_[k_idx * hidden_size + i]); + } + out_tensor_[i] = (T_OUT)(sum); + } +} + +/// GEMM Grouped +template +class BaseSplitkGrouped { + public: + using BaseKernel = BaseKernel_; + + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + /// Argument structure + using Arguments = typename BaseKernel::Arguments; + + using 
ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; + + protected: + /// Kernel parameters object + typename BaseKernel::Params gemm_params_; + + private: + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) { + int32_t tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; + BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); + tiles += problem_tile_count(problem); + } + return tiles; + } + + /// Copy from `data` to `workspace` + Status copy_to_workspace(void* workspace, void* data, size_t bytes) { + cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); + if (cuda_error != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + cuda_error = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Precomputes scheduling information for the grouped GEMM + Status precompute(Arguments const& args, int32_t tile_count, void* workspace) { + size_t workspace_bytes = get_workspace_size(args); + std::vector host_workspace(workspace_bytes); + BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes, args.problem_count, args.threadblock_count, + reinterpret_cast(host_workspace.data())); + return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + } + + /// Reorder `data` according to `indices` + template + static void reorder_array(T* data, std::vector const& indices) { + // For now, simply create a copy of the data and then copy over to the original. + std::vector copy(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + copy.at(i) = data[indices[i]]; + } + + memcpy(data, copy.data(), indices.size() * sizeof(T)); + } + + public: + /// Constructs the GEMM. + BaseSplitkGrouped() {} + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const& args) { return BaseKernel::can_implement(args); } + + /// Get the number of tiles in a problem + static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) { + auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); + return BaseKernel::ProblemVisitor::tile_count(grid); + } + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(Arguments const& args) { + if (args.host_problem_sizes == nullptr) { + CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); + return -1; + } + + return group_tile_count(args.host_problem_sizes, args.problem_count); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + size_t total_mn = 0; + for (int i = 0; i < args.problem_count; i++) { + total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); + } + size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices; + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size(args.host_problem_sizes, args.problem_count, + args.threadblock_count); + } + return workSpaceSize; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { return dim3(args.threadblock_count, 1, 1); } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); + + int smem_size = static_cast(sizeof(typename BaseKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel, + BaseKernel::kThreadCount, smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Sorts each pointer passed in according to the indices that sort + /// `problem_sizes_ptr` in descending order of problem-K dimension. 
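+  /// For example (illustrative): problems with K = {512, 2048, 1024} are
+  /// reordered to {2048, 1024, 512}, and the lda/ldb/ldc/ldd and offset
+  /// arrays are permuted with the same indices so each problem keeps its
+  /// own strides and operand pointers.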
+ static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr, + int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr, + int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr) { + std::vector indices(problem_count); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_ptr](size_t i, size_t j) { + return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); + }); + + reorder_array(problem_sizes_ptr, indices); + reorder_array(lda_host_ptr, indices); + reorder_array(ldb_host_ptr, indices); + reorder_array(ldc_host_ptr, indices); + reorder_array(ldd_host_ptr, indices); + reorder_array(offset_A_ptr, indices); + reorder_array(offset_B_ptr, indices); + reorder_array(offset_C_ptr, indices); + reorder_array(offset_D_ptr, indices); + } + + /// Computes the number of threadblocks to launch for the grouped kernel + static int sufficient(cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, + int available_sm_count = -1) { + // Determine the number of blocks that would be launched to fill up a single + // wave on the GPU with each SM having maximum occupancy. + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result)); + return 0; + } + + int multiprocessor_count; + result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result)); + return 0; + } + + bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count); + if (override_sm_count) { + available_sm_count = multiprocessor_count; + } + + int max_active_blocks = maximum_active_blocks(); + if (max_active_blocks <= 0) { + return 0; + } + + int occupancy_based_block_count = available_sm_count * max_active_blocks; + + if (problem_sizes_ptr == nullptr || problem_count == 0) { + return occupancy_based_block_count; + } + + int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); + + // If the group contains a single problem, launching the exact number of + // threadblocks needed to cover the problem minimizes the work performed + // per threadblock in finding the next tile to compute. We return total_tiles + // unless the user has provided the SM count. + if (problem_count == 1 && override_sm_count) { + return total_tiles; + } + + // Choose between the full wave of threadblocks and the tile count. If there + // are fewer tiles in the group than threadblocks in the full wave, only + // some threadblocks will be assigned tiles. Those threadblocks + // which are not assigned tiles still need to perform the work of iterating through + // problem sizes to determine that they have no work to do. This competes for cycles + // with those threadblocks that are assigned tiles to compute. + return std::min(total_tiles, occupancy_based_block_count); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + // Workspace + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); + } else { + gemm_params_ = typename BaseKernel::Params(args, workspace); + } + + // Specify shared memory capacity for kernel. + int smem_size = static_cast(sizeof(typename BaseKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = + cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + gemm_params_.update(args, workspace, tile_count); + } else { + gemm_params_.update(args, workspace); + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + if (!gemm_params_.problem_visitor.problem_count) { + return Status::kSuccess; + } + + // + // Launch kernel + // + + // Launch splitk grouped gemm + { + dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); + dim3 block(BaseKernel::kThreadCount, 1, 1); + + int smem_size = static_cast(sizeof(typename BaseKernel::SharedStorage)); + cutlass::Kernel<<>>(gemm_params_); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + // Launch splitkReduction + { + dim3 grid(32, gemm_params_.problem_visitor.problem_count); + dim3 block(256); + splitkReduction<<>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split, + gemm_params_.problem_visitor.problem_sizes, + gemm_params_.split_k_slices, gemm_params_.splitk_buffer_offsets); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } + + /// Initializes and runs the kernel. 
+ Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM Grouped +template +class SplitkGemmGrouped : public BaseSplitkGrouped { + public: + using GemmKernel = GemmKernel_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h similarity index 71% rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h index eb33a98e4246f..2b3478a38fc2e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -1,11 +1,12 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,51 +14,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -/* - This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is - quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices - to be consumed by CUTLASS. - - Note that for int4, ThreadBlockK MUST be 64. - - */ - #pragma once -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" #include "cutlass/arch/arch.h" #include "cutlass/arch/mma.h" -#include "cutlass/platform/platform.h" +#include "cutlass/bfloat16.h" #include "cutlass/cutlass.h" #include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" namespace cutlass { namespace gemm { namespace kernel { -template -struct LayoutDetailsB {}; - -// Volta specialiations. Volta will dequantize before STS, so we need a different operator -template -struct LayoutDetailsB { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 8; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. -// TODO - Switch this to column major for weights since gemms should be more performant. 
-template -struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 64; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - template struct MixedGemmArchTraits {}; @@ -66,7 +38,7 @@ struct MixedGemmArchTraits { static constexpr int Stages = 2; using OperatorClass = cutlass::arch::OpClassSimt; using AccType = float; - using LayoutB = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; static constexpr int ElementsPerAccessA = 1; static constexpr int ElementsPerAccessB = 1; @@ -80,10 +52,13 @@ struct MixedGemmArchTraits { // ========================= Volta Traits =========================== // Volta will always dequantize after the global memory load. // This will instantiate any HMMA tensorcore kernels for Volta. +// Note that volta does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. template struct MixedGemmArchTraits< TypeA, TypeB, cutlass::arch::Sm70, - typename cutlass::platform::enable_if::value>::type> { + typename cutlass::platform::enable_if::value || + cutlass::platform::is_same::value>::type> { private: using LayoutDetails = LayoutDetailsB; @@ -103,10 +78,13 @@ struct MixedGemmArchTraits< }; // ======================= Turing Traits ============================== +// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. template struct MixedGemmArchTraits< TypeA, TypeB, cutlass::arch::Sm75, - typename cutlass::platform::enable_if::value>::type> { + typename cutlass::platform::enable_if::value || + cutlass::platform::is_same::value>::type> { private: using LayoutDetails = LayoutDetailsB; @@ -129,7 +107,8 @@ struct MixedGemmArchTraits< template struct MixedGemmArchTraits< TypeA, TypeB, cutlass::arch::Sm80, - typename cutlass::platform::enable_if::value>::type> { + typename cutlass::platform::enable_if::value || + cutlass::platform::is_same::value>::type> { private: using LayoutDetails = LayoutDetailsB; @@ -150,4 +129,4 @@ struct MixedGemmArchTraits< } // namespace kernel } // namespace gemm -} // namespace cutlass \ No newline at end of file +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_int8_traits.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_int8_traits.h new file mode 100644 index 0000000000000..fe4bc0940d9e8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_int8_traits.h @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassSimt; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; +}; + +// ======================= Turing Traits ============================== +template <> +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; +}; + +// ======================= Ampere Traits ============================== +template <> +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h new file mode 100644 index 0000000000000..9339be92dfb2a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h @@ -0,0 +1,206 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. 
Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" + +#include "cutlass/layout/permute.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, + /// Operation performed by GEMM + typename Operator = typename device::DefaultGemmConfiguration::Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// + typename Enable = void> +struct DefaultSplitkGemmGrouped; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag 
indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Permute result D + typename PermuteDLayout> +struct DefaultSplitkGemmGrouped::value>::type> { + // If true, we must construct a 'transposed-and-exchanged' Mma operator. + static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = + kernel::detail::MapArguments; + + // Define the default GEMM kernel + using DefaultGemmKernel = + typename kernel::DefaultGemm::GemmKernel; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::SplitkGemmGrouped; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h new file mode 100644 index 0000000000000..778d45f39eab3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -0,0 +1,513 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +template +inline constexpr bool dependent_false_v = false; +} + +template +struct GemmFpAIntB { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. 
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + int group_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, + typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : problem_size(problem_size), + group_size(group_size), + ref_A(ref_A), + ref_B(ref_B), + ref_scale(ref_scale), + ref_zero(ref_zero), + ref_C(ref_C), + ref_D(ref_D), + batch_count(serial_split_k_factor), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) {} + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + int group_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* gather_A_indices; + int const* 
gather_B_indices; + int const* scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : swizzle_log_tile(0), semaphore(0), gemm_k_size(0) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size), + group_size(args.group_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.ref_A.layout()), + ref_A(args.ref_A), + params_B(args.ref_B.layout()), + ref_B(args.ref_B), + params_scale(args.ref_scale.layout()), + ref_scale(args.ref_scale), + ref_zero(args.ref_zero), + params_C(args.ref_C.layout()), + ref_C(args.ref_C), + params_D(args.ref_D.layout()), + ref_D(args.ref_D), + output_op(args.output_op), + semaphore(static_cast(workspace)), + gemm_k_size(gemm_k_size), + gather_A_indices(args.gather_A_indices), + gather_B_indices(args.gather_B_indices), + scatter_D_indices(args.scatter_D_indices) {} + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement(Arguments const& args) { + static int const kAlignmentA = + (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = + (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + + static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; + + static int const kAlignmentC = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(args.ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!args.ref_scale.good()) { + return Status::kErrorNotSupported; + } + + if constexpr (hasZero(Mma::QuantOp)) { + if (!args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } else { + if (args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } + + if constexpr (isFinegrained(Mma::QuantOp)) { + if (args.group_size != 64 && args.group_size != 128) { + return Status::kErrorNotSupported; + } + } + + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + + // Initializes the fine grained scale+bias iterator. 
Needed since the fine grained iterator + // has a different constructor signature than a regular cutlass iterator + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + } + + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, + typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, + int group_size) { + return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && kInterleave == 1 || + platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; + cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), {params.problem_size.m(), problem_size_k}, + thread_idx, tb_offset_A, params.gather_A_indices); + + typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, + thread_idx, tb_offset_B, params.gather_B_indices); + + typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? 
problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = initialize_scale( + params.params_scale, params.ref_scale.data(), params.ref_zero.data(), + {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. +#else + static_assert(false, + "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h new file mode 100644 index 0000000000000..6cb5cc4e1334c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h @@ -0,0 +1,66 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
\file
+  \brief Scheduler for grouped GEMM
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
+#include "cutlass/matrix_coord.h"
+
+#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_problem_visitor.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace kernel {
+
+/// Visitor class to abstract away the algorithm for iterating over tiles
+template <typename ThreadblockShape, GroupScheduleMode GroupScheduleMode_, int PrefetchTileCount, int ThreadCount,
+          bool Transposed = false>
+struct GemmMoeProblemVisitor
+    : public MoeProblemVisitor<detail::GemmGroupedProblemSizeHelper<Transposed>, ThreadblockShape,
+                               GroupScheduleMode_, PrefetchTileCount, ThreadCount> {
+  static bool const kTransposed = Transposed;
+
+  using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper<Transposed>;
+  using Base =
+      MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount>;
+  using Params = typename Base::Params;
+  using SharedStorage = typename Base::SharedStorage;
+
+  //
+  // Methods
+  //
+  CUTLASS_DEVICE
+  GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx)
+      : Base(params_, shared_storage_, block_idx) {}
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace kernel
+}  // namespace gemm
+}  // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h
new file mode 100644
index 0000000000000..fb35b2dbf12cf
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h
@@ -0,0 +1,516 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. + + original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/trace.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" + +namespace tk = tensorrt_llm::common; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithEpilogueVisitor { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = + typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static 
int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + tk::QuantMode quant_option; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {} + + /// constructs an arguments structure + Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, TensorRefB ref_B_, + tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, TensorRefAlphaRow ref_alpha_row_, + TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_, int64_t batch_stride_B_, + typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(mode_), + problem_size(problem_size_), + batch_count(batch_count_), + ref_A(ref_A_), + ref_B(ref_B_), + quant_option(quant_option_), + ref_alpha_col(ref_alpha_col_), + ref_alpha_row(ref_alpha_row_), + ref_C(ref_C_), + ref_D(ref_D_), + batch_stride_A(batch_stride_A_), + batch_stride_B(batch_stride_B_), + batch_stride_D(0), + epilogue_visitor(epilogue_visitor_) {} + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void* ptr_A; + void* ptr_B; + tk::QuantMode quant_option; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), + params_A(0), + params_B(0), + params_alpha_col(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_alpha_col(nullptr), + ptr_alpha_row(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + batch_stride_A(0), + batch_stride_B(0) {} + + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) + : problem_size(args.problem_size), + swizzle_log_tile(0), + params_A(args.ref_A.layout()), + params_B(args.ref_B.layout()), + params_alpha_col(args.ref_alpha_col.layout()), + params_alpha_row(args.ref_alpha_col.layout()), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(args.problem_size.k()), + 
ptr_A(args.ref_A.data()), + ptr_B(args.ref_B.data()), + quant_option(args.quant_option), + ptr_alpha_col(args.ref_alpha_col.data()), + ptr_alpha_row(args.ref_alpha_row.data()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + epilogue_visitor(args.epilogue_visitor) { + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = + const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { return can_implement(args.problem_size); } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + 
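+  // Worked example of the split-K partitioning computed in Params::Params above (an
+  // editorial sketch, not part of the original file; it assumes ElementA = ElementB =
+  // half_t, i.e. sizeof_bits = 16, and problem_size.k() = 4096 with batch_count = 3):
+  //   kAlignK              = max(max(128 / 16, 128 / 16), 1) = 8
+  //   gemm_k_size          = round_up(ceil_div(4096, 3), 8) = round_up(1366, 8) = 1368
+  //   grid_tiled_shape.k() = ceil_div(4096, 1368) = 3
+  // Two slices then cover 1368 elements of K each and the last covers 1360, so every
+  // slice boundary stays 128b-aligned, as kSplitKAlignment requires.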
+#define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + +#if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } +#endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B(params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
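+    // (Editorial note: without the broadcast, the compiler must treat threadIdx.x / 32
+    //  as per-thread data and may generate divergent code; shuffling lane 0's value
+    //  makes the warp index provably uniform across the warp.)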
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, shared_storage.epilogue.visitor, params.problem_size.mn(), thread_idx, warp_idx, + lane_idx, params.params_alpha_col, params.params_C, params.params_D, params.quant_option, params.ptr_alpha_row, + params.ptr_alpha_col, params.ptr_C, params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 720) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + // replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. + run_kernel(params, shared_storage); +#else + static_assert(false, + "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels.");
+#endif
+#else
+  CUTLASS_NOT_IMPLEMENTED();
+#endif
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace kernel
+}  // namespace gemm
+}  // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h
new file mode 100644
index 0000000000000..35d22b2f55a89
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h
@@ -0,0 +1,126 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/*
+ This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is
+ quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
+ to be consumed by CUTLASS.
+
+ Note that for int4, ThreadBlockK MUST be 64.
+
+ */
+
+#pragma once
+
+#include "cutlass/layout/matrix.h"
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/arch/mma.h"
+#include "cutlass/platform/platform.h"
+
+#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h"
+#include "contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h"
+
+namespace cutlass {
+namespace gemm {
+namespace kernel {
+
+template <typename TypeB, typename Arch, typename Enable = void>
+struct LayoutDetailsB {};
+
+// Volta specializations. Volta will dequantize before STS, so we need a different operator
+template <typename TypeB>
+struct LayoutDetailsB<TypeB, arch::Sm70> {
+  static constexpr int ThreadblockK = 64;
+  using Layout = layout::ColumnMajor;
+  static constexpr int ElementsPerAccess = 8;
+  using Operator = cutlass::arch::OpMultiplyAdd;
+};
+
+// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
+// Switch this to column major for weights since gemms should be more performant.
+template <typename Arch>
+struct LayoutDetailsB<half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
+  static constexpr int ThreadblockK = 64;
+  using Layout = layout::ColumnMajor;
+  static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
+  using Operator = cutlass::arch::OpMultiplyAdd;
+};
+
+template <typename Arch>
+struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
+  static constexpr int ThreadblockK = 64;
+  using Layout = layout::ColumnMajor;
+  static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
+  using Operator = cutlass::arch::OpMultiplyAdd;
+};
+
+// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
+// which signals that we want to dequantize after loading from smem.
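+// A concrete check of the interleaving arithmetic used below (editorial sketch, not part
+// of the original file):
+//   uint8_t : ElementsPerCacheLine = 128 * 8 / 8 = 128, ColumnsInterleaved = 128 / 64 = 2
+//   uint4b_t: ElementsPerCacheLine = 128 * 8 / 4 = 256, ColumnsInterleaved = 256 / 64 = 4
+// That is, one 128-byte cache line holds the full ThreadblockK = 64 slice of 2 (int8) or
+// 4 (int4) adjacent columns, which is what ColumnMajorTileInterleave encodes.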
+template + struct LayoutDetailsB < + uint8_t, + Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template + struct LayoutDetailsB < + uint4b_t, + Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template +struct LayoutDetailsB= 90>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 90>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h new file mode 100644 index 0000000000000..9e3e9d20d7f6e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -0,0 +1,471 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
\file
+  \brief
+*/
+
+#pragma once
+
+#include "cutlass/complex.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/fast_math.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/matrix_coord.h"
+#include "cutlass/semaphore.h"
+
+#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
+#include "cutlass/layout/matrix.h"
+#include "cutlass/trace.h"
+
+#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
+#include "contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace kernel {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+// This section exists so that we can use the same kernel code for regular gemm and dequantizing gemms.
+// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global.
+template <typename...>
+using void_t = void;
+
+template <typename Mma, typename = void>
+struct use_dq_gemm : platform::false_type {};
+
+template <typename Mma>
+struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>> : platform::true_type {};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Mma_, typename Epilogue_, typename ThreadblockSwizzle_, typename KernelArch,
+          GroupScheduleMode GroupScheduleMode_>
+struct MoeFCGemm {
+ public:
+  using Mma = Mma_;
+  using Epilogue = Epilogue_;
+  using EpilogueOutputOp = typename Epilogue::OutputOp;
+  using ThreadblockSwizzle = ThreadblockSwizzle_;
+  static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
+  static bool const kTransposed = false;
+
+  // Optional transpose
+  using MapArguments =
+      kernel::detail::MapArguments<typename Mma_::IteratorA::Element, typename Mma_::IteratorA::Layout,
+                                   ComplexTransform::kNone, Mma_::IteratorA::AccessType::kElements,
+                                   typename Mma_::IteratorB::Element, typename Mma_::IteratorB::Layout,
+                                   ComplexTransform::kNone, Mma_::IteratorB::AccessType::kElements,
+                                   typename Mma_::LayoutC, kTransposed>;
+
+  // Public-facing type definitions related to operand element type, layout, and complex conjugate
+  // operation. Must interact with the 'kTransposed' notion.
+  static_assert(!kTransposed, "Transpose problem not supported");
+  using ElementA = typename MapArguments::ElementA;
+  using LayoutA = typename MapArguments::LayoutA;
+  using ElementB = typename MapArguments::ElementB;
+  using LayoutB = typename MapArguments::LayoutB;
+  using ElementC = typename Epilogue::OutputTileIterator::Element;
+  using LayoutC = typename MapArguments::LayoutC;
+  using ElementScale = ElementC;
+
+  static ComplexTransform const kTransformA = MapArguments::kTransformA;
+  static ComplexTransform const kTransformB = MapArguments::kTransformB;
+
+  // Type definitions about the mainloop.
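+  // (Illustrative values only, not from the original file: with a typical fp16 tensor-op
+  //  tiling such as ThreadblockShape = GemmShape<128, 128, 64>, WarpShape =
+  //  GemmShape<64, 64, 64> and InstructionShape = GemmShape<16, 8, 16>, WarpCount is
+  //  2x2x1, so kThreadCount = 32 * 4 = 128 threads per CTA.)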
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = + GemmMoeProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + int problem_count; + int threadblock_count; + int group_size; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t* total_rows_before_expert; + int64_t gemm_n; + int64_t gemm_k; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + weight_scales(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + total_rows_before_expert(nullptr), + gemm_n(0), + gemm_k(0), + host_problem_sizes(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op, + ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C, + ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, + GemmCoord* host_problem_sizes = nullptr) + : problem_count(problem_count), + threadblock_count(threadblock_count), + group_size(group_size), + output_op(output_op), + ptr_A(const_cast(ptr_A)), + ptr_B(const_cast(ptr_B)), + weight_scales(const_cast(weight_scales)), + ptr_C(const_cast(ptr_C)), + ptr_D(ptr_D), + total_rows_before_expert(total_rows_before_expert), + gemm_n(gemm_n), + gemm_k(gemm_k), + host_problem_sizes(nullptr) { + if (platform::is_same::value || platform::is_same::value) { + assert(weight_scales); + } + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int group_size; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() : ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr) {} + + CUTLASS_HOST_DEVICE + explicit Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor(args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, + tile_count), + threadblock_count(args.threadblock_count), + group_size(args.group_size), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + weight_scales(args.weight_scales), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D) {} + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { 
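+      // (Editorial note: update() rebinds the pointers and output op so a constructed
+      //  Params can be reused across launches; group_size is not refreshed here and
+      //  keeps the value captured at construction.)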
+ problem_visitor = typename ProblemVisitor::Params(args.total_rows_before_expert, args.gemm_n, args.gemm_k, + args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + weight_scales = args.weight_scales; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + MoeFCGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { return Status::kSuccess; } + + static Status can_implement(Arguments const& args) { + if (platform::is_same::value || platform::is_same::value) { + if (args.weight_scales == nullptr) { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t"); + return Status::kInvalid; + } + } else if (args.weight_scales != nullptr) { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); + return Status::kInvalid; + } else if (args.group_size != args.gemm_k) { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)"); + return Status::kInvalid; + } else if (static_cast(args.gemm_n) < Mma::IteratorB::AccessType::kElements) { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment"); + return Status::kInvalid; + } + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert(platform::is_same::value && kInterleave == 1 || + platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // + // Problem visitor. 
+ // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + // Outer 'persistent' loop to iterate over tiles + int loop = 0; + while (problem_visitor.next_tile()) { + loop++; + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset(static_cast(cta_idx / grid_shape.n()) * Mma::Shape::kM, + static_cast(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + const int64_t rows_to_jump = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + + char* byte_ptr_B = (reinterpret_cast(params.ptr_B)) + problem_idx * bytes_per_expert_matrix; + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B = + platform::is_same::value ? gemm_n : gemm_k * kInterleave; + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, + {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, + tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + auto CreateMMA = [&]() { + if constexpr (use_dq_gemm::value) + return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + else + return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + }; + Mma mma = CreateMMA(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. 
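+      // (SharedStorage is a union, so this tile's mainloop buffers alias the previous
+      //  tile's epilogue buffers; the barrier below keeps fast warps from overwriting
+      //  shared memory that slower warps may still be reading.)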
+ __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); + + if constexpr (use_dq_gemm::value) { + const MatrixCoord scale_extent = {1, problem_size.n()}; + typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), weight_scale_ptr, + scale_extent, thread_idx, tb_offset_scale); + + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } else { + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; + ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + LayoutC layout_C(0); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params_C, ptr_C, problem_size.mn(), thread_idx, + threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params_D, ptr_D, problem_size.mn(), thread_idx, + threadblock_offset.mn()); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + run_kernel(params, + shared_storage); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. +#else + // static_assert(false, + // "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels."); + ; +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_problem_visitor.h similarity index 79% rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_problem_visitor.h index 1de8f6b69642c..6852d4c811b4d 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_problem_visitor.h @@ -1,33 +1,19 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * http://www.apache.org/licenses/LICENSE-2.0 * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ /*! 
\file \brief Base scheduler for grouped problems, using MoE @@ -106,7 +92,7 @@ struct BaseMoeProblemVisitor { /// Get the grid shape CUTLASS_HOST_DEVICE - static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem) { return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); } @@ -145,9 +131,9 @@ struct BaseMoeProblemVisitor { } CUTLASS_HOST_DEVICE - static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { return ProblemSizeHelper::tile_count(grid); } + static int32_t tile_count(cutlass::gemm::GemmCoord const& grid) { return ProblemSizeHelper::tile_count(grid); } - static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count) { + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count) { int32_t total_tiles = 0; for (int32_t i = 0; i < problem_count; ++i) { auto problem = host_problem_sizes_ptr[i]; @@ -276,13 +262,13 @@ struct MoeProblemVisitor +struct SplitkGemmGrouped { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; + + // Optional transpose + using MapArguments = + kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + + using ElementFinalOutput = typename MapArguments::ElementA; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. 
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = + GemmGroupedProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmCoord* problem_sizes; + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // splitK + int split_k_slices; + int64_t* splitk_buffer_offsets; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr), + host_problem_sizes(nullptr), + split_k_slices(1), + splitk_buffer_offsets(nullptr) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count, + typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, + ElementFinalOutput** ptr_C, ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda, + typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc, + typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices, + int64_t* splitk_buffer_offsets) + : problem_sizes(problem_sizes), + problem_count(problem_count), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_D(ptr_D), + lda(lda), + ldb(ldb), + ldc(ldc), + ldd(ldd), + host_problem_sizes(host_problem_sizes), + split_k_slices(split_k_slices), + splitk_buffer_offsets(splitk_buffer_offsets) {} + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + ElementC* ptr_C_split; + ElementC* ptr_D_split; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; + + // + // Methods + // + + // splitk + GemmCoord grid_tiled_shape; + int swizzle_log_tile; + int gemm_k_size; + GemmCoord* host_problem_sizes; + int split_k_slices; + int64_t* splitk_buffer_offsets; + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + 
ptr_D(nullptr), + ptr_C_split(nullptr), + ptr_D_split(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr), + swizzle_log_tile(0), + gemm_k_size(0), + host_problem_sizes(nullptr), + split_k_slices(1), + splitk_buffer_offsets(nullptr) {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), + host_problem_sizes(args.host_problem_sizes), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D), + ptr_C_split(reinterpret_cast<ElementC*>(workspace)), + ptr_D_split(reinterpret_cast<ElementC*>(workspace)), + lda(args.lda), + ldb(args.ldb), + ldc(args.ldc), + ldd(args.ldd), + split_k_slices(args.split_k_slices), + splitk_buffer_offsets(args.splitk_buffer_offsets) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.host_problem_sizes[0], {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); + + // only support same k + int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK; + int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + } + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { + problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + ptr_C_split = reinterpret_cast<ElementC*>(workspace); + ptr_D_split = reinterpret_cast<ElementC*>(workspace); + + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldd = args.ldd; + } + }; + + /// Shared memory storage structure + struct SharedStorage { + union { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + } kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + SplitkGemmGrouped() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { return Status::kSuccess; } + + static Status can_implement(Arguments const& args) { return Status::kSuccess; } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + // + // Problem visitor.
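Params() above slices the shared K dimension into split_k_slices chunks that are whole multiples of the threadblock K tile; as the "only support same k" comment notes, every problem in the group is assumed to share one K. A small sketch of that arithmetic, with kTileK standing in for Mma::Shape::kK:

// Example: K = 4096, kTileK = 64, split_k_slices = 4
//   full_gemm_k_iterations = 4096 / 64 = 64
//   gemm_k_iterations      = 64 / 4   = 16
//   gemm_k_size            = 16 * 64  = 1024 K-columns per slice
int gemm_k_size_sketch(int K, int kTileK, int split_k_slices) {
  int full_gemm_k_iterations = K / kTileK;
  int gemm_k_iterations = full_gemm_k_iterations / split_k_slices;
  return gemm_k_iterations * kTileK;
}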
+ // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + // Load element pointers. Exchange pointers and strides if working on the transpose + ElementA* ptr_A = + reinterpret_cast<ElementA*>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + + ElementB* ptr_B = + reinterpret_cast<ElementB*>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + cutlass::gemm::GemmCoord threadblock_offset(static_cast<int>(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + static_cast<int>(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, + 0); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k; + if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) { + problem_size_k = problem_size.k(); + } else { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, + tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + // Wait for all threads to finish their epilogue phases from the previous tile.
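Two pieces of index math above are worth isolating: a flat threadblock index becomes an (m, n) tile origin, and each K-slice except the last covers exactly gemm_k_size columns. A compact sketch under those assumptions (grid_n, kTileM, kTileN are stand-ins for grid_shape.n() and Mma::Shape::kM/kN):

struct TileOrigin { int m, n; };

// threadblock_idx -> tile origin, as in the threadblock_offset computation above.
TileOrigin tile_origin(int threadblock_idx, int grid_n, int kTileM, int kTileN) {
  return {(threadblock_idx / grid_n) * kTileM, (threadblock_idx % grid_n) * kTileN};
}

// K extent visible to slice k_idx: interior slices end at (k_idx + 1) * gemm_k_size,
// the last slice is clamped to the problem's full K.
int problem_size_k(int k_idx, int grid_k, int gemm_k_size, int K) {
  return (k_idx + 1 == grid_k) ? K : (k_idx + 1) * gemm_k_size;
}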
+ __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = params.ptr_C_split; + ElementC* ptr_D = params.ptr_D_split; + + LayoutC layout_C(params.ldc[problem_idx]); + LayoutC layout_D(params.ldd[problem_idx]); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // assume identity swizzle + MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n()); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params_C, ptr_C, problem_size.mn(), thread_idx, + threadblock_offset_C); + + iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params_D, ptr_D, problem_size.mn(), thread_idx, + threadblock_offset_C); + iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma.h new file mode 100644 index 0000000000000..8bbc1ee4e6c47 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -0,0 +1,120 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { +//////////////////////////////////////////////////////////////////////////////// + +// We need to distinguish here, since we want volta support. It is too much effort +// to write shared memory iterators that are probably needed for volta to function +// properly. 
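Note how the epilogue above writes partial products into the split-K workspace (ptr_C_split/ptr_D_split) rather than the user's output: slice k of a problem advances by problem_size.m() * problem_size.n() elements, and splitk_buffer_offsets[problem_idx] locates that problem's region in the shared buffer, scaled by gridDim.z. A hedged sketch of the offset, with num_slices playing the role of gridDim.z:

#include <cstdint>

// Element offset for one (problem, slice) pair, matching the
// add_pointer_offset() calls on iterator_C and iterator_D above.
int64_t splitk_offset(int64_t m, int64_t n, int64_t slice_idx, int64_t num_slices,
                      const int64_t* splitk_buffer_offsets, int problem_idx) {
  return m * n * slice_idx + num_slices * splitk_buffer_offsets[problem_idx];
}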
As a result, we allow converters both after the LDG (for volta) and after +// the LDS for Turing+. +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Warp level Mma + typename MmaOperator, + /// Math operation perform by warp level operator + typename MathOperator> +struct SetConverters {}; + +// Dequantize after LDG, so set transforms accordingly +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = + FastInterleavedAndBiasedNumericArrayConverter; + + using TransformAfterLDS = + NumericArrayConverter; +}; + +// Dequantize after LDS, so set transforms accordingly + +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = + NumericArrayConverter; + + using TransformAfterLDS = + FastInterleavedAndBiasedNumericArrayConverter; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale_, + /// Layout for the scale operand + typename LayoutScale_, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// + typename Enable = void> +struct DqMma; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h new file mode 100644 index 0000000000000..8b9d6b0b14add --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -0,0 +1,289 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
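SetConverters above only decides where the widening conversion runs: after the global load (LDG) on Volta, or after the shared-memory load (LDS) on Turing and newer. The conversion itself turns packed unsigned weights into half-precision values. A standalone device-side sketch of the int4-to-half widening step; unpack_uint4x2 is a hypothetical helper, and the real converter in this patch may fold the zero-point bias differently:

#include <cuda_fp16.h>
#include <cstdint>

// Widen one byte holding two 4-bit weights into two halves.
// Subtracting 8 recenters the unsigned nibble around zero.
__device__ inline void unpack_uint4x2(uint8_t packed, __half out[2]) {
  out[0] = __int2half_rn(static_cast<int>(packed & 0x0F) - 8);
  out[1] = __int2half_rn(static_cast<int>((packed >> 4) & 0x0F) - 8);
}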
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIterators; + +// Fine grained iterators +template +struct DefaultScaleIterators> { + using IteratorScale = + cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +// Per column iterators +template +struct DefaultScaleIterators> { + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); + + private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + + public: + // Define iterators over tiles from the scale operand + using IteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, Element, Layout, 0, + IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for elementA + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// + typename Operator_, + /// + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && + !layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using 
Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, + layout::RowMajor, OperatorClass, std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + using ScaleIterators = + DefaultScaleIterators; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, + typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, + layout::RowMajor, typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>; +}; + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// + typename Operator_, + /// + SharedMemoryClearOption 
SharedMemoryClear> +struct DqMma= 80 && + layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, layout::ColumnMajor, + ElementAccumulator, layout::RowMajor, OperatorClass, std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator; + + using ScaleIterators = + DefaultScaleIterators; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, + typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, + layout::RowMajor, typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git 
a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h new file mode 100644 index 0000000000000..91c4cd342569e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -0,0 +1,245 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + static constexpr bool DqAfterLDG = platform::is_same::value; + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaCoreElementA = typename platform::conditional::type; + using 
MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, MmaCoreElementA, LayoutA, MmaCoreElementB, LayoutB, + ElementAccumulator, layout::RowMajor, OperatorClass, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = + transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, ElementScale, + LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>; + + using SmemScaleType = typename platform::conditional::type; + using SmemIteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, + SmemScaleType, LayoutScale, 0, IteratorScaleThreadMap, + kAlignmentScale>; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, + IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, + typename Converters::TransformAfterLDG, typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + static constexpr 
bool DqAfterLDG = platform::is_same::value; + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaCoreElementA = typename platform::conditional::type; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, MmaCoreElementA, LayoutA, MmaCoreElementB, layout::ColumnMajor, + ElementAccumulator, layout::RowMajor, OperatorClass, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileIterator; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = + transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, ElementScale, + LayoutScale, 0, IteratorScaleThreadMap, kAlignmentScale>; + + using SmemScaleType = typename platform::conditional::type; + using SmemIteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator, + SmemScaleType, LayoutScale, 0, IteratorScaleThreadMap, + kAlignmentScale>; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, + IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, + typename Converters::TransformAfterLDG, typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma.h new file mode 100644 index 0000000000000..1a3e7e39c9656 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma.h @@ -0,0 +1,283 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
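For the column-major tile-interleaved B specializations, the global-memory iterator shape is refolded so one contiguous access spans ColumnsInterleaved columns: rows grow by that factor while columns shrink by it, leaving the element count unchanged. A worked sketch of the GmemIteratorShape arithmetic (values are illustrative):

struct Shape2D { int rows, cols; };

// Mirrors MatrixShape<kK * ColumnsInterleaved, kN / ColumnsInterleaved>.
constexpr Shape2D gmem_iterator_shape(int kK, int kN, int columns_interleaved) {
  return {kK * columns_interleaved, kN / columns_interleaved};
}

// With MmaCore::Shape kK = 64, kN = 128 and ColumnsInterleaved = 2,
// the gmem iterator sees a 128 x 64 tile instead of 64 x 128.
static_assert(gmem_iterator_shape(64, 128, 2).rows == 128, "rows scale up");
static_assert(gmem_iterator_shape(64, 128, 2).cols == 64, "columns scale down");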
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma_bf16.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles 
from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma_bf16.h new file mode 100644 index 0000000000000..4afd482f85628 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -0,0 +1,345 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
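The fp16 x fp16 two-stage specialization above deliberately instantiates DefaultMmaCore with 3 stages ("3 is used on purpose here") so that the cp.async multistage components are selected, while the mainloop can still run with the requested 2 stages to save shared memory on large tiles. A one-line sketch of that stage-count choice:

#include <algorithm>

// Core components are built for at least 3 stages (multistage code path),
// even when the kernel is configured to pipeline only 2.
constexpr int core_stages(int requested_stages) { return std::max(requested_stages, 3); }
static_assert(core_stages(2) == 3, "a 2-stage request still selects multistage components");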
+ */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + private: + // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS. + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaElementA = typename platform::conditional::type; + using MmaElementB = typename platform::conditional::type; + + public: + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaPipelined; +}; + +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaMultistage; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & 
int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory 
clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h new file mode 100644 index 0000000000000..cf5ba6faa0c82 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -0,0 +1,237 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
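Throughout these specializations the scale alignment is derived from a 128-bit vector access: kAlignmentScale = 128 / sizeof_bits<ElementScale>::value. A worked sketch of that arithmetic:

// Elements per 128-bit access for a given element width in bits.
constexpr int alignment_for_128bit_access(int element_bits) { return 128 / element_bits; }

static_assert(alignment_for_128bit_access(16) == 8, "half/bf16 scales: 8 per access");
static_assert(alignment_for_128bit_access(32) == 4, "float scales: 4 per access");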
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// +// SFINAE trick so I can keep the same loop code for Volta and dispatch to the +// correct warp level mma. On volta, all data is stored to shared memory as FP16. +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, + typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, + typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset) { + warp_mma(D, A, B, C); +} + +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, + typename WarpMma::TransformedFragmentA const& A, + typename WarpMma::TransformedFragmentB const& B, typename WarpMma::FragmentC const& C, + int const warp_tileB_k_offset) { + warp_mma(D, A, B, C, warp_tileB_k_offset); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// The type of the scales + typename ElementScale_, + /// Number of stages, + int Stages, + /// The dequantizing op to be performed. + WeightOnlyQuantOp DequantOp, + /// Used for partial specialization, + typename Enable = bool> +class DqMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; + + static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); + + // Finegrained scales get streamed in via cp.async + static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; + // We always have scales. + static constexpr int ScaleElementsPerStage = Shape::kN; + // We sometimes have a bias + static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
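DqMmaBase here keeps separate cadences for issuing warp MMAs and for loading B fragments: one LDSM for B can feed several MMA K-steps, so B only advances every kNumKIterationsPerWarpBLoad iterations. A sketch of the iteration bookkeeping under illustrative parameter names:

// warp_k    = WarpGemm::kK                      mma_k   = Operator::Policy::MmaShape::kK
// ldsm_rows = IteratorB::InstructionShape::kRow instr_k = Operator::InstructionShape::kK
constexpr int warp_gemm_iterations(int warp_k, int mma_k) { return warp_k / mma_k; }

constexpr int warp_gemm_iterations_for_b(int warp_k, int mma_k, int ldsm_rows, int instr_k) {
  int per_b_load = ldsm_rows / instr_k;  // kNumKIterationsPerWarpBLoad
  return warp_gemm_iterations(warp_k, mma_k) / per_b_load;
}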
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + static constexpr int kNumKIterationsPerWarpBLoad = + Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = + MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + /// Shape of the shared memory buffer for the scales for the B matrix. + using ShapeScale = MatrixShape; + /// Shape of the shared memory buffer for the biases of the B matrix. + using ShapeZero = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_zero; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { return TensorRefA{operand_A.data(), LayoutA()}; } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h 
b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h new file mode 100644 index 0000000000000..f11e94d9d2b95 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
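The template declared next is only a forward declaration; the two headers included at the end of this file each define a partial specialization, and the trailing Enable parameter selects between them according to whether the quantization op is fine-grained. A compilable sketch of that selection pattern follows, with toy names (`QuantOpKind`, `Mainloop`, `is_finegrained`) standing in for the real enum and mainloop class.

```cpp
#include <type_traits>

// Toy mirror of the weight-only quant-op predicate.
enum class QuantOpKind { PER_COLUMN_SCALE_ONLY, FINEGRAINED_SCALE_ONLY, FINEGRAINED_SCALE_AND_ZEROS };

constexpr bool is_finegrained(QuantOpKind op) { return op != QuantOpKind::PER_COLUMN_SCALE_ONLY; }

// Primary template: declared, never defined, exactly like DqMmaMultistage below.
template <QuantOpKind Op, typename Enable = void>
struct Mainloop;

// "percol" specialization: scales are loaded once in the prologue.
template <QuantOpKind Op>
struct Mainloop<Op, std::enable_if_t<!is_finegrained(Op)>> {
  static constexpr bool streams_scales = false;
};

// "finegrained" specialization: scales (and zeros) stream in with each stage.
template <QuantOpKind Op>
struct Mainloop<Op, std::enable_if_t<is_finegrained(Op)>> {
  static constexpr bool streams_scales = true;
};

static_assert(!Mainloop<QuantOpKind::PER_COLUMN_SCALE_ONLY>::streams_scales, "");
static_assert(Mainloop<QuantOpKind::FINEGRAINED_SCALE_AND_ZEROS>::streams_scales, "");
```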
+template <
+ /// Size of the Gemm problem - concept: gemm::GemmShape<>
+ typename Shape_,
+ /// Iterates over tiles of A operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator |
+ // MaskedTileIterator)
+ typename IteratorA_,
+ /// Iterates over tiles of A operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorA_,
+ /// Cache operation for operand A
+ cutlass::arch::CacheOperation::Kind CacheOpA,
+ /// Iterates over tiles of B operand in global memory
+ // (concept: ReadableTileIterator | ForwardTileIterator |
+ // MaskedTileIterator)
+ typename IteratorB_,
+ /// Iterates over tiles of B operand in shared memory
+ /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+ typename SmemIteratorB_,
+ /// Cache operation for operand B
+ cutlass::arch::CacheOperation::Kind CacheOpB,
+ /// Data type for the scales
+ typename IteratorScale_,
+ /// Iterators over scales in shared memory
+ typename SmemIteratorScale_,
+ /// Data type of accumulator matrix
+ typename ElementC_,
+ /// Data type of accumulator matrix
+ typename LayoutC_,
+ /// Policy describing tuning details (concept: MmaPolicy)
+ typename Policy_,
+ /// Number of stages,
+ int Stages,
+ /// Converter for B matrix applied immediately after the LDS
+ typename TransformBAfterLDS_,
+ /// The quantization operator being used
+ WeightOnlyQuantOp QuantOp_,
+ /// Use zfill or predicate for out-of-bound cp.async
+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
+ /// Used for partial specialization
+ typename Enable = void>
+class DqMmaMultistage;
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
+
+#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h"
+#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h"
diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h
new file mode 100644
index 0000000000000..dd934b9a00369
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h
@@ -0,0 +1,634 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC 
= ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + /// The group size for quantization + 
int group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); + typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr = + reinterpret_cast(this->smem_iterator_scale_.get_scale()); + typename IteratorScale::AccessType* smem_zero_ptr = + reinterpret_cast(this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) { + cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); + } + + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, IteratorScale& iterator_scale, + int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + 
CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? 
kSrcBytes : 0); + + cutlass::arch::cp_async_zfill(dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill(dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - #committed)
+ cutlass::arch::cp_async_wait<Base::kStages - 2>();
+ __syncthreads();
+
+ // Pair of fragments used to overlap shared memory loads and math
+ // instructions
+ WarpFragmentA warp_frag_A[2];
+ WarpFragmentB warp_frag_B[2];
+ typename Dequantizer::FragmentScale warp_frag_scales;
+ typename Dequantizer::FragmentZero warp_frag_zeros;
+
+ Operator warp_mma;
+
+ this->warp_tile_iterator_A_.set_kgroup_index(0);
+ this->warp_tile_iterator_B_.set_kgroup_index(0);
+
+ this->warp_tile_iterator_A_.load(warp_frag_A[0]);
+ this->warp_tile_iterator_B_.load(warp_frag_B[0]);
+
+ warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
+
+ ++this->warp_tile_iterator_A_;
+ ++this->warp_tile_iterator_B_;
+ warp_dequantizer_.add_pointer_offset(Shape::kN);
+
+ iterator_A.clear_mask(gemm_k_iterations == 0);
+ iterator_B.clear_mask(gemm_k_iterations == 0);
+ iterator_scale.clear_mask(gemm_k_iterations == 0);
+
+ int smem_write_stage_idx = Base::kStages - 1;
+ int smem_read_stage_idx = 0;
+
+ //
+ // Mainloop
+ //
+
+ CUTLASS_GEMM_LOOP
+ for (; gemm_k_iterations > (-Base::kStages + 1);) {
+ //
+ // Loop over GEMM K dimension
+ //
+
+ // Computes a warp-level GEMM on data held in shared memory
+ // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
+ CUTLASS_PRAGMA_UNROLL
+ for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
+ // Load warp-level tiles from shared memory, wrapping to k offset if
+ // this is the last group as the case may be.
+
+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
+ ++this->warp_tile_iterator_A_;
+
+ int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
+ int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) {
+ this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
+ ++this->warp_tile_iterator_B_;
+ }
+
+ typename TransformBAfterLDS::result_type converted_frag_B =
+ lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
+ warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros);
+
+ run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum,
+ warp_tileB_k_compute_offset);
+
+ // Issue global->shared copies for this stage
+ if (warp_mma_k < Base::kWarpGemmIterations - 1) {
+ int group_start_iteration_A, group_start_iteration_B;
+
+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
+
+ copy_tiles_and_advance(iterator_A, iterator_B, iterator_scale, group_start_iteration_A,
+ group_start_iteration_B);
+
+ // This is the first group of a given stage, so we issue the loads for the B scales immediately.
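Each such load goes through copy_scales_and_advance above, whose group-size logic decides when the staged scale row moves on. A worked trace of that rule follows as a standalone sketch; the numbers (64 rows of B staged per call, six calls) are assumptions for illustration.

```cpp
#include <cstdio>

int main() {
  for (int group_size : {64, 128}) {
    int scale_row = 0;        // which row of scales/zeros is staged next
    int row_groupsize64 = 0;  // 64-row chunks of B staged so far
    std::printf("group_size=%d:", group_size);
    for (int call = 0; call < 6; ++call) {
      std::printf(" %d", scale_row);
      if (group_size == 64) {
        ++scale_row;                                    // a new group every 64 rows
      } else if (group_size == 128 && (row_groupsize64 & 0x1)) {
        ++scale_row;                                    // a new group every 128 rows
      }
      ++row_groupsize64;
    }
    std::printf("\n");  // 64 -> 0 1 2 3 4 5; 128 -> 0 0 1 1 2 2
  }
}
```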
+ if (group_start_iteration_B == 0) { + copy_scales_and_advance(iterator_scale); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, iterator_scale, group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - + // #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + } + } + + // Load the scale needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + // Update internal pointer to set of scales in shared memory. + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h new file mode 100644 index 0000000000000..33bcb19106381 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h @@ -0,0 +1,586 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
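Relative to the fine-grained variant above, the per-column mainloop declared next keeps a single scale per output column and no zero point, loading it once in the prologue and reusing it for every k step. A toy sketch of that dequantization rule, with made-up shapes and values:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int kN = 4;  // output columns
  constexpr int kK = 2;  // k steps
  float scale[kN] = {0.5f, 0.25f, 1.0f, 2.0f};         // one scale per column, loaded once
  int8_t q[kK][kN] = {{3, -2, 5, 1}, {-7, 4, 0, -1}};  // quantized B tile

  for (int k = 0; k < kK; ++k)
    for (int n = 0; n < kN; ++n)
      std::printf("B[%d][%d] = %6.2f\n", k, n, q[k][n] * scale[n]);  // w = q * scale[n]
}
```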
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
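The Detail struct that follows spreads one stage's worth of cp.async copies across the warp-level k iterations via a ceiling division. Worked numbers as a standalone sketch; the counts are purely illustrative, not taken from a real tile:

```cpp
#include <cstdio>

int main() {
  constexpr int kCopiesPerStage = 14;  // stands in for AsyncCopyIterationsPerStageA/B
  constexpr int kWarpIterations = 4;   // stands in for Base::kWarpGemmIterations
  // Ceil division, exactly as in kAccessesPerGroupA/B.
  constexpr int kPerGroup = (kCopiesPerStage + kWarpIterations - 1) / kWarpIterations;

  for (int g = 0; g < kWarpIterations; ++g) {
    int begin = g * kPerGroup;
    int end = begin + kPerGroup;
    if (end > kCopiesPerStage) end = kCopiesPerStage;  // bounds check, as in copy_tiles_and_advance
    std::printf("iteration %d issues copies [%d, %d)\n", g, begin, end);
  }
}
```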
+ struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + ///< Group size for quantization. 
Not used by this main loop since it assumes per-column + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } 
+ ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group as + // the first load of A. + FragmentScale tb_frag_scales; + tb_frag_scales.clear(); + iterator_scale.load(tb_frag_scales); + this->smem_iterator_scale_.store(tb_frag_scales); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill(dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill(dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. 
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + warp_dequantizer_.load(warp_frag_scales); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
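The index arithmetic in the loop body below decides when a fresh B fragment is actually loaded: one shared-memory load of B can feed several MMA k steps. A standalone trace under assumed iteration counts:

```cpp
#include <cstdio>

int main() {
  // Assumed counts: four warp-level k iterations per stage, and one B load
  // feeding two MMA k steps (as when narrow B data is packed for fp16 math).
  constexpr int kWarpGemmIterations = 4;          // Base::kWarpGemmIterations
  constexpr int kNumKIterationsPerWarpBLoad = 2;  // Base::kNumKIterationsPerWarpBLoad
  static_assert(kWarpGemmIterations % kNumKIterationsPerWarpBLoad == 0, "");

  for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k) {
    int compute_offset = warp_mma_k % kNumKIterationsPerWarpBLoad;  // slice of the held B fragment
    int load_offset = warp_mma_k / kNumKIterationsPerWarpBLoad;     // which B fragment is live
    bool prefetch_B = (compute_offset == kNumKIterationsPerWarpBLoad - 1);
    std::printf("k=%d: B fragment %d, slice %d%s\n", warp_mma_k, load_offset, compute_offset,
                prefetch_B ? ", load next B" : "");
  }
}
```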
+
+ this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+ this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
+ ++this->warp_tile_iterator_A_;
+
+ int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
+ int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) {
+ this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
+ this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
+ ++this->warp_tile_iterator_B_;
+ }
+
+ typename TransformBAfterLDS::result_type converted_frag_B =
+ lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
+ warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
+
+ run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum,
+ warp_tileB_k_compute_offset);
+
+ // Issue global->shared copies for this stage
+ if (warp_mma_k < Base::kWarpGemmIterations - 1) {
+ int group_start_iteration_A, group_start_iteration_B;
+
+ group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
+ group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
+
+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
+ }
+
+ if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
+ int group_start_iteration_A, group_start_iteration_B;
+ group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
+ group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
+
+ copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
+
+ // Inserts a memory fence between stages of cp.async instructions.
+ cutlass::arch::cp_async_fence();
+
+ // Waits until kStages-2 stages have committed.
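For reference, a small host-side model of what that wait guarantees. The real synchronization is performed by the cp.async hardware, so this is only an accounting sketch under an assumed pipeline depth:

```cpp
#include <cassert>

int main() {
  constexpr int kStages = 4;    // assumed pipeline depth
  int committed = kStages - 1;  // stage-sized commit groups issued so far
  int completed = 0;            // groups known complete after the wait

  // cp_async_wait<kStages - 2> semantics: block until at most kStages - 2
  // commit groups remain outstanding.
  while (committed - completed > kStages - 2) {
    ++completed;
  }

  assert(completed >= 1);  // so the oldest stage is fully resident in shared memory
  return 0;
}
```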
+ arch::cp_async_wait<Base::kStages - 2>();
+ __syncthreads();
+
+ // Move to the next stage
+ iterator_A.add_tile_offset({0, 1});
+ iterator_B.add_tile_offset({1, 0});
+
+ this->smem_iterator_A_.add_tile_offset({0, 1});
+ this->smem_iterator_B_.add_tile_offset({1, 0});
+
+ // Add negative offsets to return iterators to the 'start' of the
+ // circular buffer in shared memory
+ if (smem_write_stage_idx == (Base::kStages - 1)) {
+ this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
+ this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
+ smem_write_stage_idx = 0;
+ } else {
+ ++smem_write_stage_idx;
+ }
+
+ if (smem_read_stage_idx == (Base::kStages - 1)) {
+ this->warp_tile_iterator_A_.add_tile_offset(
+ {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
+ this->warp_tile_iterator_B_.add_tile_offset(
+ {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
+ smem_read_stage_idx = 0;
+ } else {
+ ++smem_read_stage_idx;
+ }
+
+ --gemm_k_iterations;
+ iterator_A.clear_mask(gemm_k_iterations == 0);
+ iterator_B.clear_mask(gemm_k_iterations == 0);
+ }
+ }
+ }
+
+ if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
+ // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
+ cutlass::arch::cp_async_fence();
+ cutlass::arch::cp_async_wait<0>();
+ __syncthreads();
+ }
+ }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h
new file mode 100644
index 0000000000000..2c85ba8a1995e
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h
@@ -0,0 +1,379 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Used for partial specialization + typename Enable = bool> +class DqMmaPipelined : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type 
of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = + warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, + typename SmemIteratorScale::Fragment::Element, LayoutScale, 32, QuantOp>; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // statically assert kStages for DqMmaPipelined is two (double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined( + typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int const group_size, ///< Unused; kept so the constructor signature matches the fine-grained variants. + ///< DqMmaPipelined is only instantiated for sm<80, so accepting the argument does + ///< not change compilation for sm>=80.
+ int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) { ///< source accumulator tile + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA = + NumericArrayConverter<typename SmemIteratorA::Element, typename IteratorA::Element, FragmentA::kElements>; + + using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Element, + typename IteratorScale::Element, FragmentScale::kElements>; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
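+ // Example (illustrative, not part of the upstream kernel): with bfloat16 activations that must + // issue HMMA on sm70/sm75, TransformA instantiates as roughly + // NumericArrayConverter<half_t, bfloat16_t, FragmentA::kElements>; + // with fp16 activations, source and destination types match and the conversion is an identity copy.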
+ TransformA transformA; + TransformScale transformScale; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + FragmentScale tb_frag_scales; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + WarpFragmentScale warp_frag_scales; + + tb_frag_A.clear(); + tb_frag_B.clear(); + tb_frag_scales.clear(); + + // The last kblock is loaded in the prologue + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + iterator_scale.load(tb_frag_scales); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + warp_dequantizer_.load(warp_frag_scales); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
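+ // Illustrative numbers: for an fp16 x int4 kernel, one 16x8x32 B load feeds two 16x8x16 mma + // instructions, so Base::kNumKIterationsPerWarpBLoad == 2 and a fresh B fragment is requested on + // every second compute iteration.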
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index((warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, + warp_tileB_k_compute_offset); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h new file mode 100644 index 0000000000000..f0b6f4fcaad33 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -0,0 +1,103 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. 
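+ + Example (illustrative): for half_t A and uint4b_t B, LoadInstructionK in this file + evaluates to 8 * 16 / 4 == 32, so B is read from shared memory with a K=32 instruction shape + while the compute instruction keeps K=16.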
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix product operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix product operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp<WarpShape_, InstructionShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + arch::OpMultiplyAddDequantizeInterleavedBToA, PartitionsK, AccumulatorsInRowMajor> { + private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; + + // Chosen so we get K=16 for int8 and K=32 for int4. + static constexpr int LoadInstructionK = 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value; + + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>; + + public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma<InstructionShape_, 32, ElementA, cutlass::layout::RowMajor, ElementA, + cutlass::layout::ColumnMajor, ElementC, cutlass::layout::RowMajor, arch::OpMultiplyAdd>, + cutlass::MatrixShape<1, 1>>; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_, ElementA, LayoutA, ElementB, LayoutB, + ElementC, LayoutC, Policy, LoadInstructionShape, + PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h new file mode 100644 index 0000000000000..a368c6d220266 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission.
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting Tensor Cores, where the B operand has been +/// converted to the f16 compute type before the mma. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved.
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOpPolicy) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value && + platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value) || + (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value && + platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value && + ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports underlying HMMA"); + + static_assert(platform::is_same<ElementA, half_t>::value || + (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + public: + /// Iterates over the A operand in memory + using IteratorA = + MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA, + MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>; + + /// Iterates over the B operand in memory + using IteratorB = + MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB, + MatrixShape<SharedMemoryInstructionShape::kK, ArchMmaOperator::Shape::kN>, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename
Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, + int const warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " + "B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A); + MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B); + MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ?
(MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h new file mode 100644 index 0000000000000..51ca8282e42ff --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -0,0 +1,534 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
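+ + Example (illustrative): once int4 weights have been converted to f16 in registers, the + dequantizer below multiplies each N-column fragment by its per-column scale (and, for + fine-grained schemes with zero points, adds the corresponding zero) before the tensor core + mma consumes it.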
+*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include "contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + WeightOnlyQuantOp QuantOp_, + /// + typename Enable = void> +class MmaTensorOpDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer< + MmaOperator_, Shape_, Operand::kB, bfloat16_t, layout::RowMajor, 32, QuantOp_, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 80 && + platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma operator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction.
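+ // Example (illustrative): an int4 B operand loaded with a 16x8x32 instruction that feeds + // 16x8x16 bf16 HMMAs gives kExpansionFactor == 32 / 16 == 2.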
+ static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array<bfloat16_t, MmaOperator::FragmentB::kElements>; + + // Fragment to hold scale data to apply to B before mma + // We need 1 bf16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>; + using FragmentZero = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef<ElementScale, Layout>; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) {} + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, + FragmentScale const& zero_frag) { + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); + } + + // Adds a pointer offset in units of elements.
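+ // (Illustrative use: a grouped MoE GEMM can advance the dequantizer from one expert's scale + // matrix to the next with such an offset; this is an example, not a usage contract.)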
+ CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Turing & Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer< + MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 75 && + platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma operator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array<half_t, MmaOperator::FragmentB::kElements>; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>; + using FragmentZero = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef<ElementScale, Layout>; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) {} + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + multiplies<ExpandedMmaOperandB> mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + CUTLASS_DEVICE + void load(FragmentScale&
scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, + FragmentScale const& zero_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + multiplies<ExpandedMmaOperandB> mul_op; + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag); + + if constexpr (hasZero(QuantOp)) { + plus<ExpandedMmaOperandB> plus_op; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = + plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer< + MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_, + typename platform::enable_if< + platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value && + platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::RowMajor>::value>::type> { + public: + static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, ""); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma operator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array<half_t, MmaOperator::FragmentB::kElements>; + + /// Warp mma shape + using Shape = Shape_; + + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B.
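+ // Worked example (illustrative): with Shape::kN == 64, TileNIterations == 64 / 32 == 2, so each + // thread holds 2 * 8 == 16 scale elements for the interleaved Volta mma.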
+ static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array<ElementScale, TileNIterations * 8>; + using AccessType = Array<ElementScale, 8>; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef<ElementScale, Layout>; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const base_col = lane_idx & 0xF8; + int const thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + AccessType* scale_frag_ptr = reinterpret_cast<AccessType*>(&scale_frag); + + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + scale_frag_ptr[tile_iter] = *reinterpret_cast<AccessType const*>(pointer_ + ColsPerMmaTile * tile_iter); + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { + static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, ""); + + multiplies<FragmentDequantizedOperand> mul_op; + operand_frag = mul_op(operand_frag, scale_frag); + } + + private: + ElementScale const* pointer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer< + MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_, + typename platform::enable_if< + platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value && + platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> { + public: + static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, ""); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma operator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array<half_t, MmaOperator::FragmentB::kElements>; + + /// Warp mma shape + using Shape = Shape_; + + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B.
+ static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array<ElementScale, TileNIterations * 2>; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef<ElementScale, Layout>; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const base_col = (lane_idx & 0xF8) + lane_idx % 4; + int const thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + // For col major B, each thread will jump 4 cols to get its next value inside + // of the super mma. + CUTLASS_PRAGMA_UNROLL + for (int mma_iter = 0; mma_iter < 2; ++mma_iter) { + scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter]; + } + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { + using MmaOperandB = typename ArchMmaOperator::FragmentB; + static constexpr int total_n_mmas = 2 * TileNIterations; + static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, ""); + + multiplies<MmaOperandB> mul_op; + + MmaOperandB* operand_frag_ptr = reinterpret_cast<MmaOperandB*>(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + private: + ElementScale const* pointer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h new file mode 100644 index 0000000000000..0841218a480ba --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace ort_fastertransformer { +// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight-only quantization.
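+// Example (illustrative): CtaShape128x128x64_WarpShape64x32x64 below describes a 128x128x64 +// threadblock tile split into 64x32x64 warp tiles; for fp16 x int4 weight-only kernels its +// trailing K = 64 must equal ThreadblockK from the kernel layout details.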
+enum class CutlassTileConfig { + // Signals that we should run heuristics to choose a config + Undefined, + + // Signals that we should run heuristics to choose a config + ChooseWithHeuristic, + + // SIMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=16 + CtaShape16x128x64_WarpShape16x32x64, + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x64x128_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x64x64, + CtaShape128x128x64_WarpShape128x32x64, + CtaShape128x256x64_WarpShape64x64x64, + + // Warp configs for M=256 + CtaShape256x128x64_WarpShape64x64x64, + + // TensorCore config CTA_N = 256, CTA_K = 64 + CtaShape16x256x64_WarpShape16x64x64 +}; + +enum class SplitKStyle { + NO_SPLIT_K, + SPLIT_K_SERIAL, + // SPLIT_K_PARALLEL // Not supported yet +}; + +enum class CutlassTileConfigSM90 { + // Signals that we should run heuristics to choose a config + Undefined, + + // Signals that we should run heuristics to choose a config + ChooseWithHeuristic, + + // CTA configs for M=64 + CtaShape64x16x128B, + CtaShape64x32x128B, + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + + // CTA configs for M=128 + CtaShape128x16x128B, + CtaShape128x32x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, +}; + +enum class MainloopScheduleType { + AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this + // defaults to the "legacy" main loop schedule. +}; + +enum class EpilogueScheduleType { + AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For + // architectures older than Hopper, the epilogue is always performed by the same thread block as the main loop.
+}; + +enum class ClusterShape { ClusterShape_1x1x1, + ClusterShape_2x1x1, + ClusterShape_1x2x1, + ClusterShape_2x2x1 }; + +struct CutlassGemmConfig { + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; + + // config options for sm90 + CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; + MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; + EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; + ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + + CutlassGemmConfig() {} + + CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) + : tile_config(tile_config), split_k_style(split_k_style), split_k_factor(split_k_factor), stages(stages) {} + + CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule, + EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) + : tile_config_sm90(tile_config_sm90), + mainloop_schedule(mainloop_schedule), + epilogue_schedule(epilogue_schedule), + cluster_shape(cluster_shape) {} +}; + +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h new file mode 100644 index 0000000000000..7fd1745aa2c54 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/interleaved_numeric_conversion.h @@ -0,0 +1,392 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
+ \file + \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/numeric_types.h" + +namespace cutlass { + +// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low +// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally +// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. +// This converter will uninterleave the data and subtract the bias while converting to the result type. +template <typename T, typename S, int N> +struct FastInterleavedAndBiasedNumericArrayConverter {}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4> { + using result_type = Array<half_t, 4>; + using source_type = Array<uint8_t, 4>; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast<uint32_t*>(&result); + uint32_t const i8s = reinterpret_cast<uint32_t const&>(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template <int N> +struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N> { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array<half_t, N>; + using source_type = Array<uint8_t, N>; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH> convert_vector_; + + result_type result; + using vec_result = Array<scalar_result_type, VEC_WIDTH>; + using vec_source = Array<scalar_source_type, VEC_WIDTH>; + + vec_result* result_ptr = reinterpret_cast<vec_result*>(&result); + vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4> { + using result_type = Array<bfloat16_t, 4>; + using source_type = Array<uint8_t, 4>; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result); + uint32_t const i8s = reinterpret_cast<uint32_t const&>(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; + } + + // Truncate the fp32 representation and pack up as bfloat16s. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = + __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should convert directly to FP16 using FP16 converters. + result.clear(); // Suppress compiler warning + arch::device_breakpoint(); +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template <int N> +struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N> { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array<bfloat16_t, N>; + using source_type = Array<uint8_t, N>; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH> convert_vector_; + + result_type result; + using vec_result = Array<scalar_result_type, VEC_WIDTH>; + using vec_source = Array<scalar_source_type, VEC_WIDTH>; + + vec_result* result_ptr = reinterpret_cast<vec_result*>(&result); + vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8> { + using result_type = Array<half_t, 8>; + using source_type = Array<uint4b_t, 8>; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast<uint32_t*>(&result); + uint32_t const i4s = reinterpret_cast<uint32_t const&>(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits beforehand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required.
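+ // Worked example (illustrative): a raw nibble of 9 encodes the weight 9 - 8 = 1; lop3 builds + // the fp16 bit pattern 0x6409 (1033.0), and subtracting 0x6408 (1032.0) below leaves 1.0.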
+    const uint32_t top_i4s = i4s >> 8;
+    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[0])
+                 : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[1])
+                 : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[2])
+                 : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[3])
+                 : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+
+    // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
+    // half2 ctor. In this case, I chose performance reliability over code readability.
+
+    // This is the half2 {1032, 1032} represented as an integer.
+    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
+    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
+    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
+    // This is the half2 {-72, -72} represented as an integer.
+    static constexpr uint32_t NEG_72 = 0xd480d480;
+
+    // Finally, we construct the output numbers.
+    // Convert elt_01
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
+    // Convert elt_23
+    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
+    // Convert elt_45
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
+    // Convert elt_67
+    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
+
+    return result;
+  }
+
+  CUTLASS_DEVICE
+  result_type operator()(source_type const& s) { return convert(s); }
+};
+
+template <int N>
+struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N> {
+  static constexpr int VEC_WIDTH = 8;
+  static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
+
+  using result_type = Array<half_t, N>;
+  using source_type = Array<uint4b_t, N>;
+
+  CUTLASS_DEVICE
+  static result_type convert(source_type const& source) {
+    using scalar_result_type = typename result_type::Element;
+    using scalar_source_type = typename source_type::Element;
+    FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH> convert_vector_;
+
+    result_type result;
+    using vec_result = Array<scalar_result_type, VEC_WIDTH>;
+    using vec_source = Array<scalar_source_type, VEC_WIDTH>;
+
+    vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
+    vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / VEC_WIDTH; ++i) {
+      result_ptr[i] = convert_vector_(source_ptr[i]);
+    }
+
+    return result;
+  }
+
+  CUTLASS_DEVICE
+  result_type operator()(source_type const& s) { return convert(s); }
+};
+
+template <>
+struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8> {
+  using result_type = Array<bfloat16_t, 8>;
+  using source_type = Array<uint4b_t, 8>;
+
+  CUTLASS_DEVICE
+  static result_type convert(source_type const& source) {
+    result_type result;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
+
+    uint32_t* h = reinterpret_cast<uint32_t*>(&result);
+    uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);
+
+    // First, we extract the i4s and construct an intermediate bf16 number.
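+    // [Editorial worked example, added for clarity; not part of the upstream TRT-LLM/CUTLASS
+    // source.] 0x4300 is the bf16 bit pattern of 128.0, whose mantissa ULP at that exponent is
+    // exactly 1, so (v | 0x4300) encodes 128 + v for a biased nibble v in [0, 15]. The bf16 fma
+    // below with BF16_ONE (= 1.0) and BF16_BIAS (= -136.0) then produces 128 + v - 136 = v - 8,
+    // i.e. the original signed int4 value.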
+    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
+    static constexpr uint32_t MASK = 0x000f000f;
+    static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
+
+    // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
+    // No shift needed for first item.
+    uint32_t i4s = source_i4s;
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                 : "=r"(h[0])
+                 : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+    CUTLASS_PRAGMA_UNROLL
+    for (int ii = 1; ii < result_type::kElements / 2; ++ii) {
+      i4s >>= sizeof_bits<typename source_type::Element>::value;
+      // (i4s & 0x000f000f) | 0x43004300
+      asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                   : "=r"(h[ii])
+                   : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
+    }
+
+    // This is the BF16 {-136, -136} represented as an integer.
+    static constexpr uint32_t BF16_BIAS = 0xC308C308;
+    static constexpr uint32_t BF16_ONE = 0x3F803F80;
+
+    // Finally, we construct the output numbers.
+    CUTLASS_PRAGMA_UNROLL
+    for (int ii = 0; ii < result_type::kElements / 2; ++ii) {
+      // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
+      asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
+    }
+#else
+    // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
+    // HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
+    arch::device_breakpoint();
+    result.clear();  // Suppress compiler warning.
+#endif
+    return result;
+  }
+
+  CUTLASS_DEVICE
+  result_type operator()(source_type const& s) { return convert(s); }
+};
+
+template <int N>
+struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N> {
+  static constexpr int VEC_WIDTH = 8;
+  static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
+
+  using result_type = Array<bfloat16_t, N>;
+  using source_type = Array<uint4b_t, N>;
+
+  CUTLASS_DEVICE
+  static result_type convert(source_type const& source) {
+    using scalar_result_type = typename result_type::Element;
+    using scalar_source_type = typename source_type::Element;
+    FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH> convert_vector_;
+
+    result_type result;
+    using vec_result = Array<scalar_result_type, VEC_WIDTH>;
+    using vec_source = Array<scalar_source_type, VEC_WIDTH>;
+
+    vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
+    vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / VEC_WIDTH; ++i) {
+      result_ptr[i] = convert_vector_(source_ptr[i]);
+    }
+
+    return result;
+  }
+
+  CUTLASS_DEVICE
+  result_type operator()(source_type const& s) { return convert(s); }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h
similarity index 98%
rename from onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h
rename to onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h
index 3505bea24e4d9..e5abefa35bc84 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h
+++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/tile_interleaved_layout.h
@@ -42,7 +42,7 @@ namespace cutlass {
 namespace layout {
 
 template <int RowsPerTile, int ColumnsInterleaved>
-class ColumnMajorTileInterleave {
+struct ColumnMajorTileInterleave {
   static constexpr int kRowsPerTile = RowsPerTile;
   static constexpr int
kColumnsInterleaved = ColumnsInterleaved; }; diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h new file mode 100644 index 0000000000000..79811ef3e611b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM + quantization. 
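+
+  [Editorial note: the following is an interpretation added for clarity, not part of the original
+  header. The row coordinate handed to this iterator is expressed in units of 64 activation rows
+  (row_groupsize64_), so for a quantization group size group_size the scale/zero row actually read
+  is row_groupsize64 / (group_size / 64); Params below precomputes inc_advance_ as the byte offset
+  of Shape::kRow scale rows so the iterator can step between mainloop tiles.]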
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +class FineGrainedScaleZeroIterator; + +template +class FineGrainedScaleZeroIterator { + public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; + + static int const kAccessesPerVector = 1; + + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using AccessType = AlignedArray; + + // For compatibility with existing iterator interface + struct Params { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + explicit Params(Layout const& layout) : stride_(layout.stride(0)) { + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + BytePointer pointer_zero_; + + bool is_valid_ = false; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor + Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params), + pointer_scale_(reinterpret_cast(const_cast(pointer_scale))), + pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset = + threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + if (pointer_zero_ != nullptr) { + pointer_zero_ += 
(tb_row_byte_offset + tb_col_byte_offset); + } + + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; + + int const thread_row = thread_id / THREADS_PER_ROW; + int const thread_col = thread_id % THREADS_PER_ROW; + + const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + if (pointer_zero_ != nullptr) { + pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); + } + + // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on + // a given iteration. The same threads will be responsible for issues reads since the number of scales + // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ + // outside of the constructor. + int const global_row = threadblock_offset.row() + thread_row; + int const global_col = threadblock_offset.column() + thread_col * kAlignment; + + bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; + bool const col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator(params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), + group_size) {} + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + if (pointer_zero_ != nullptr) { + pointer_zero_ += row_byte_offset + col_byte_offset; + } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) { is_valid_ &= (!enable); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return is_valid_; } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const { return reinterpret_cast(pointer_scale_); } + + /// Returns a zero pointer + CUTLASS_HOST_DEVICE + AccessType* get_zero() const { return reinterpret_cast(pointer_zero_); } +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h new file mode 100644 index 0000000000000..403221a956017 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/cutlass_extensions/weight_only_quant_op.h @@ -0,0 +1,50 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. +*/ + +#pragma once + +namespace cutlass { + +enum class WeightOnlyQuantOp { UNDEFINED, + PER_COLUMN_SCALE_ONLY, + FINEGRAINED_SCALE_ONLY, + FINEGRAINED_SCALE_AND_ZEROS }; + +constexpr bool isFinegrained(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +} + +constexpr bool hasZero(WeightOnlyQuantOp op) { return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; } + +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc index adc043e5689e2..cd59e904ad9eb 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc @@ -151,8 +151,8 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector(ctas_per_wave); + const float current_score = static_cast(num_waves_total) - num_waves_fractional; const float score_slack = 0.1f; if (current_score < config_score || diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h index e70efe0503b55..0f75a121b3b92 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h @@ -16,7 +16,7 @@ #pragma once -#include "ft_gemm_configs.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h" #include #include diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h deleted file mode 100644 index a5faad423fad9..0000000000000 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -namespace ort_fastertransformer { -// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape -// in the kernel layout details when doing weight only quantization. -enum class CutlassTileConfig { - // Signals that we should run heuristics do choose a config - Undefined, - - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, - - // SiMT config - CtaShape128x128x8_WarpShape64x64x8, - - // TensorCore configs CTA_N = 128, CTA_K = 64 - // Warp configs for M=32 - CtaShape32x128x64_WarpShape32x32x64, - - // Warp configs for M=64 - CtaShape64x128x64_WarpShape32x64x64, - CtaShape64x128x64_WarpShape64x32x64, - - // Warp configs for M=128 - CtaShape128x128x64_WarpShape64x32x64, - CtaShape128x128x64_WarpShape128x32x64 -}; - -enum class SplitKStyle { - NO_SPLIT_K, - SPLIT_K_SERIAL, - // SPLIT_K_PARALLEL // Not supported yet -}; - -struct CutlassGemmConfig { - CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; - SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; - int split_k_factor = -1; - int stages = -1; -}; - -} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h deleted file mode 100644 index cfe306c2482a5..0000000000000 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h +++ /dev/null @@ -1,463 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include "cutlass/complex.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/trace.h" - -#include "gemm_moe_problem_visitor.h" -#include "tile_interleaved_layout.h" -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms. -// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global. -template -using void_t = void; - -template -struct use_dq_gemm : platform::false_type {}; - -template -struct use_dq_gemm> : platform::true_type {}; - -// SFINAE overload for dequantizing gemm -template ::value, bool>::type = true> -CUTLASS_DEVICE static void run_mma(Mma mma, int gemm_k_iterations, typename Mma::FragmentC& accum, - typename Mma::IteratorA iterator_A, typename Mma::IteratorB iterator_B, - typename Mma::FragmentC const& src_accum, ElementScale* weight_scale_ptr, - MatrixCoord scale_extent, const int thread_idx, MatrixCoord tb_offset_scale) { - typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), weight_scale_ptr, - scale_extent, thread_idx, tb_offset_scale); - - mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_scale, src_accum); -} - -// SFINAE overload for normal gemm. This completely ignores the scale parameters -template ::value, bool>::type = true> -CUTLASS_DEVICE static void run_mma(Mma mma, int gemm_k_iterations, typename Mma::FragmentC& accum, - typename Mma::IteratorA iterator_A, typename Mma::IteratorB iterator_B, - typename Mma::FragmentC const& src_accum, ElementScale* weight_scale_ptr, - MatrixCoord scale_extent, const int thread_idx, MatrixCoord tb_offset_scale) { - mma(gemm_k_iterations, accum, iterator_A, iterator_B, src_accum); -} - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MoeFCGemm { - public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - static bool const kTransposed = false; - - // Optional transpose - using MapArguments = - kernel::detail::MapArguments; - - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. 
- static_assert(!kTransposed, "Transpose problem not supported"); - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename MapArguments::LayoutC; - using ElementScale = ElementC; - - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; - - // Type definitions about the mainloop. - using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using ProblemVisitor = - GemmMoeProblemVisitor; - - // - // Structures - // - - /// Argument structure - struct Arguments { - // - // Data members - // - - int problem_count; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA* ptr_A; - ElementB* ptr_B; - ElementScale* weight_scales; - ElementC* ptr_C; - ElementC* ptr_D; - - int64_t* total_rows_before_expert; - int64_t gemm_n; - int64_t gemm_k; - - // Only used by device-level operator - GemmCoord* host_problem_sizes; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() - : problem_count(0), - threadblock_count(0), - ptr_A(nullptr), - ptr_B(nullptr), - weight_scales(nullptr), - ptr_C(nullptr), - ptr_D(nullptr), - total_rows_before_expert(nullptr), - gemm_n(0), - gemm_k(0), - host_problem_sizes(nullptr) {} - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments(int problem_count, int threadblock_count, typename EpilogueOutputOp::Params output_op, - const ElementA* ptr_A, const ElementB* ptr_B, const ElementScale* weight_scales, const ElementC* ptr_C, - ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, - GemmCoord* host_problem_sizes = nullptr) - : problem_count(problem_count), - threadblock_count(threadblock_count), - output_op(output_op), - ptr_A(const_cast(ptr_A)), - ptr_B(const_cast(ptr_B)), - weight_scales(const_cast(weight_scales)), - ptr_C(const_cast(ptr_C)), - ptr_D(ptr_D), - total_rows_before_expert(total_rows_before_expert), - gemm_n(gemm_n), - gemm_k(gemm_k), - host_problem_sizes(host_problem_sizes) { - if (platform::is_same::value || platform::is_same::value) { - assert(weight_scales); - } - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params { - typename ProblemVisitor::Params problem_visitor; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA* ptr_A; - ElementB* ptr_B; - ElementScale* weight_scales; - ElementC* ptr_C; - ElementC* ptr_D; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() : ptr_A(nullptr), ptr_B(nullptr), weight_scales(nullptr), ptr_C(nullptr), ptr_D(nullptr) {} - - CUTLASS_HOST_DEVICE - 
Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - : problem_visitor(args.total_rows_before_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, - tile_count), - threadblock_count(args.threadblock_count), - output_op(args.output_op), - ptr_A(args.ptr_A), - ptr_B(args.ptr_B), - weight_scales(args.weight_scales), - ptr_C(args.ptr_C), - ptr_D(args.ptr_D) {} - - CUTLASS_HOST_DEVICE - void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { - problem_visitor = typename ProblemVisitor::Params(args.total_rows_before_expert, args.gemm_n, args.gemm_k, - args.problem_count, workspace, tile_count); - threadblock_count = args.threadblock_count; - output_op = args.output_op; - ptr_A = args.ptr_A; - ptr_B = args.ptr_B; - weight_scales = args.weight_scales; - ptr_C = args.ptr_C; - ptr_D = args.ptr_D; - } - }; - - /// Shared memory storage structure - union SharedStorage { - typename ProblemVisitor::SharedStorage problem_visitor; - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - public: - // - // Methods - // - - CUTLASS_DEVICE - MoeFCGemm() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { return Status::kSuccess; } - - static Status can_implement(Arguments const& args) { - if (args.weight_scales != nullptr) { - CUTLASS_TRACE_HOST( - "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); - return Status::kInvalid; - } - return Status::kSuccess; - } - - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { - return 0; - } - - // The dummy template parameter is not used and exists so that we can compile this code using - // a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in - // a namespace - template - struct KernelRunner { - CUTLASS_DEVICE - static void run_kernel(Params const& params, SharedStorage& shared_storage) { CUTLASS_NOT_IMPLEMENTED(); } - }; - - template - struct KernelRunner { - CUTLASS_DEVICE - static void run_kernel(Params const& params, SharedStorage& shared_storage) { - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. - // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - static_assert(platform::is_same::value && kInterleave == 1 || - platform::is_same::value && kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // - // Problem visitor. 
- // - ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - - const int64_t gemm_k = params.problem_visitor.gemm_k; - const int64_t gemm_n = params.problem_visitor.gemm_n; - int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; - - // Outer 'persistent' loop to iterate over tiles - while (problem_visitor.next_tile()) { - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); - - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - - cutlass::gemm::GemmCoord threadblock_offset(int(cta_idx / grid_shape.n()) * Mma::Shape::kM, - int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); - - // Load element pointers. Exchange pointers and strides if working on the transpose - const int64_t rows_to_jump = - problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; - ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; - typename LayoutA::LongIndex ldm_A = gemm_k; - - char* byte_ptr_B = ((char*)params.ptr_B) + problem_idx * bytes_per_expert_matrix; - ElementB* ptr_B = reinterpret_cast(byte_ptr_B); - typename LayoutB::LongIndex ldm_B = - platform::is_same::value ? gemm_n : gemm_k * kInterleave; - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_offset.m(), - 0, - }; - - cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; - - cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, - tb_offset_A); - - typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, - {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, - tb_offset_B); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - - int lane_idx = threadIdx.x % 32; - - // - // Matrix multiply phase - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Wait for all threads to finish their epilogue phases from the previous tile. - __syncthreads(); - - // Compute threadblock-scoped matrix multiply-add - ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); - run_mma(mma, gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, weight_scale_ptr, - {1, problem_size.n()}, thread_idx, tb_offset_scale); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - ElementC* ptr_C = reinterpret_cast(params.ptr_C) + problem_idx * gemm_n; - ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; - - LayoutC layout_C(0); - LayoutC layout_D(gemm_n); - - typename Epilogue::OutputTileIterator::Params params_C(layout_C); - typename Epilogue::OutputTileIterator::Params params_D(layout_D); - - // Tile iterator loading from source tensor. 
- typename Epilogue::OutputTileIterator iterator_C(params_C, ptr_C, problem_size.mn(), thread_idx, - threadblock_offset.mn()); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D(params_D, ptr_D, problem_size.mn(), thread_idx, - threadblock_offset.mn()); - - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); - - // Next tile - problem_visitor.advance(gridDim.x); - } - } - }; - - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) { -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) - static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) - static constexpr bool compile_needed = platform::is_same::value; - KernelRunner::run_kernel(params, shared_storage); -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h index e0f91ab806c85..7e29dde8f897b 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -16,8 +16,8 @@ #pragma once +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm_configs.h" #include -#include "ft_gemm_configs.h" namespace ort_fastertransformer { @@ -42,13 +42,9 @@ class MoeGemmRunner { int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream); - void moe_gemm_act(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert, - int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - ActivationType activation_type, cudaStream_t stream); - void moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, cudaStream_t stream); + int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cudaStream_t stream); private: template diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu index 1d9a249db4237..15cab9dd4a9bf 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu @@ -13,9 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif -#include "moe_gemm_kernels_template.h" +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" +#if defined(_MSC_VER) +#pragma warning(pop) +#endif namespace ort_fastertransformer { template class MoeGemmRunner; } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint4.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint4.cu new file mode 100644 index 0000000000000..1309a7c32a37a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint4.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif +namespace ort_fastertransformer { +template class MoeGemmRunner; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu index 7b250e6ca9060..0277fab9df95c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu @@ -13,8 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif -#include "moe_gemm_kernels_template.h" +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif namespace ort_fastertransformer { template class MoeGemmRunner; diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index 2a15fdfd1cc1a..d81808e217fbc 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -26,21 +26,22 @@ #pragma warning(disable : 4100) #endif +#include "cutlass/arch/arch.h" #include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" #include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/arch/arch.h" #include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" -#include "compute_occupancy.h" -#include "epilogue_helpers.h" -#include "layout_traits_helper.h" -#include "moe_cutlass_kernel.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/compute_occupancy.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/epilogue_helpers.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" +#include "contrib_ops/cuda/moe/cutlass_extensions/gemm/threadblock/default_mma.h" #if defined(_MSC_VER) #pragma warning(pop) @@ -67,10 +68,6 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream, int* kernel_occupancy = nullptr) { - if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) { - ORT_THROW("[FT Error][MoeGemm] Grouped gemm does not support split-k"); - } - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, "Specialized for half, float"); @@ -87,10 +84,11 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w using CutlassWeightType_ = typename cutlass::platform::conditional::value, cutlass::half_t, WeightType>::type; + using CutlassWeightType = CutlassWeightType_; - // We need separate config for each architecture since we will target different tensorcore instructions. For float, - // we do not target TCs. + // We need separate config for each architecture since we will target different tensorcore instructions. For + // float, we do not target TCs. 
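+  // [Editorial note, added for clarity: the description below is an assumption about
+  // default_fpA_intB_traits.h rather than a documented contract.] MixedGemmArchTraits is expected
+  // to supply the per-architecture operator class, instruction shape, and accumulator type
+  // (AccType, used below) for a given (activation, weight) type pair, falling back to SIMT for
+  // float activations since no tensor core instruction is targeted in that case.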
using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; using ElementAccumulator = typename MixedGemmArchTraits::AccType; @@ -119,17 +117,17 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w return; } int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); - if (occupancy == 0) { - ORT_THROW("[FT Error][MoE Runner] GPU lacks the shared memory resources to run GroupedGEMM kernel"); - } - const int threadblock_count = multi_processor_count * occupancy; + ORT_ENFORCE(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); + int const threadblock_count = multi_processor_count * occupancy; - typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), ElementAccumulator(0.f)); + typename EpilogueOp::Params epilogue_op(ElementAccumulator(1.f), + biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); + int const group_size = gemm_k; typename GemmGrouped::Arguments args( - num_experts, threadblock_count, epilogue_op, reinterpret_cast(A), - reinterpret_cast(B), reinterpret_cast(weight_scales), - reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, gemm_n, + num_experts, threadblock_count, group_size, epilogue_op, reinterpret_cast(A), + reinterpret_cast(B), reinterpret_cast(weight_scales), + reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, gemm_n, gemm_k); GemmGrouped gemm; @@ -231,11 +229,28 @@ template < typename T, typename WeightType, typename arch, typename EpilogueTag, typename std::enable_if::value && std::is_same::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t /*total_rows*/, - int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, - int /*sm_version*/, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) { + int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, + int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/, + int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + ORT_ENFORCE(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) { + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 32, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + } + break; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + ORT_ENFORCE(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) { + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 64, 64>>( + A, B, weight_scales, biases, C, total_rows_before_expert, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, stream, occupancy); + } + break; case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: dispatch_gemm_config, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, C, @@ -255,13 +270,13 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig gemm_config, multi_processor_count, stream, occupancy); break; case CutlassTileConfig::Undefined: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config undefined."); + ORT_THROW("GEMM config undefined."); break; 
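+      // [Editorial note, added for clarity; not part of the original change.] The tile-config
+      // names decode mechanically: CtaShape{M}x{N}x{K}_WarpShape{m}x{n}x{k} selects
+      // ThreadblockShape = GemmShape<M, N, K> and WarpShape = GemmShape<m, n, k> in the
+      // dispatch_gemm_config calls above, e.g. CtaShape32x128x64_WarpShape32x32x64 maps to
+      // GemmShape<32, 128, 64> threadblocks with GemmShape<32, 32, 64> warps.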
case CutlassTileConfig::ChooseWithHeuristic: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] gemm config should have already been set by heuristic."); + ORT_THROW("GEMM config should have already been set by heuristic."); break; default: - ORT_THROW("[FT Error][dispatch_moe_gemm_to_cutlass] Config is invalid for same type MoE tensorop GEMM."); + ORT_THROW("Config is invalid for same type tensorop GEMM."); break; } } @@ -404,51 +419,20 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightTyp cudaStream_t stream) { switch (activation_type) { case ActivationType::Relu: - run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - num_experts, stream); + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, stream); break; case ActivationType::Gelu: - run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, - gemm_k, num_experts, stream); + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, stream); break; case ActivationType::Silu: - run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - num_experts, stream); + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, + gemm_k, num_experts, stream); break; case ActivationType::Identity: - run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - num_experts, stream); - break; - case ActivationType::InvalidType: - ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); - break; - default: { - ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); - } - } -} - -template -void MoeGemmRunner::moe_gemm_act(const T* A, const WeightType* B, const T* weight_scales, - T* C, int64_t* total_rows_before_expert, int64_t total_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, - ActivationType activation_type, cudaStream_t stream) { - switch (activation_type) { - case ActivationType::Relu: - run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, - gemm_k, num_experts, stream); - break; - case ActivationType::Gelu: - run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, - gemm_k, num_experts, stream); - break; - case ActivationType::Silu: - run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, - gemm_k, num_experts, stream); - break; - case ActivationType::Identity: - run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - num_experts, stream); + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); break; case ActivationType::InvalidType: ORT_THROW("[FT Error][MoE Runner] Invalid activation type for MoE GEMM"); @@ -461,15 +445,10 @@ void MoeGemmRunner::moe_gemm_act(const T* A, const WeightType* B, template void MoeGemmRunner::moe_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, - T* C, int64_t* total_rows_before_expert, int64_t total_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream) { - if (biases != nullptr) { - run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - num_experts, stream); - } else { - run_gemm(A, B, weight_scales, nullptr, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, - 
num_experts, stream); - } + T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, + int64_t gemm_k, int num_experts, cudaStream_t stream) { + run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, + num_experts, stream); } } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index c299cdcfe6a3d..360c0aacd9c7a 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -16,11 +16,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include #include #include -#include // Ignore CUTLASS warnings about type punning #ifdef __GNUC__ @@ -54,8 +54,8 @@ static constexpr int WARP_SIZE = 32; // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. template -__launch_bounds__(TPB) __global__ void moe_softmax(const T* input, const bool* finished, T* output, - const int num_cols) { +__launch_bounds__(TPB) __global__ + void moe_softmax(const T* input, const bool* finished, T* output, const int num_cols) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -112,9 +112,9 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T*, const bool*, T*, int* } #else template -__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax, const bool* finished, T* output, - int* indices, int* source_rows, int num_experts, int k, - bool normalize_routing_weights) { +__launch_bounds__(TPB) __global__ + void moe_top_k(const T* inputs_after_softmax, const bool* finished, T* output, int* indices, int* source_rows, + int num_experts, int k, bool normalize_routing_weights) { using cub_kvp = cub::KeyValuePair; using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -398,15 +398,14 @@ void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; dim3 block_dim(WARP_SIZE, WARPS_PER_TB); - topk_gating_softmax - <<>>(input, finished, output, num_rows, indices, source_row, k, - normalize_routing_weights); + topk_gating_softmax<<>>( + input, finished, output, num_rows, indices, source_row, k, normalize_routing_weights); } template void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_output, - int* indices, int* source_row, int num_rows, int num_experts, - int k, bool normalize_routing_weights, cudaStream_t stream) { + int* indices, int* source_row, int num_rows, int num_experts, int k, + bool normalize_routing_weights, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; switch (num_experts) { @@ -453,9 +452,8 @@ void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* default: { static constexpr int TPB = 256; moe_softmax<<>>(input, finished, softmax_temp_output, num_experts); - moe_top_k - <<>>(softmax_temp_output, finished, output, indices, source_row, num_experts, k, - normalize_routing_weights); + moe_top_k<<>>(softmax_temp_output, finished, output, indices, source_row, + num_experts, k, normalize_routing_weights); } } } @@ -522,8 +520,8 @@ __global__ void compute_total_rows_before_expert_kernel(const int* sorted_expert 
total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert); } -__global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, int num_experts, - int local_num_experts, int local_experts_start_index) { +__global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, int num_experts, int local_num_experts, + int local_experts_start_index) { const int expert = blockIdx.x * blockDim.x + threadIdx.x; const int local_experts_end_index = local_experts_start_index + local_num_experts - 1; @@ -540,8 +538,7 @@ __global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, i } template -CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, - bool has_fc3, +CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights) : has_fc3_(has_fc3), total_past_rows_(0), @@ -596,12 +593,12 @@ void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, const size_t padded_experts = pad_to_multiple_of_16(num_experts); const size_t num_moe_inputs = pad_to_multiple_of_16(k * num_rows); - source_rows_ = (int*)ws_ptr; + source_rows_ = reinterpret_cast(ws_ptr); permuted_rows_ = source_rows_ + num_moe_inputs; permuted_experts_ = permuted_rows_ + num_moe_inputs; - permuted_data_ = (T*)(permuted_experts_ + num_moe_inputs); + permuted_data_ = reinterpret_cast(permuted_experts_ + num_moe_inputs); - total_rows_before_expert_ = (int64_t*)(permuted_data_ + buf_size); + total_rows_before_expert_ = reinterpret_cast(permuted_data_ + buf_size); if (has_fc3_) { fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); @@ -648,9 +645,7 @@ struct T2 { using Type = half2; }; -inline __device__ float2 operator*(const float2 a, const float2 b) { - return make_float2(a.x * b.x, a.y * b.y); -} +inline __device__ float2 operator*(const float2 a, const float2 b) { return make_float2(a.x * b.x, a.y * b.y); } inline __device__ float4 operator*(const float4 a, const float4 b) { return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); @@ -703,15 +698,13 @@ void elementWiseMul(T* output, T const* input, int inter_size, int num_tokens, c if (inter_size & 3 == 0) { using vec_type = typename T4::Type; int const threads = std::min(inter_size / 4, 1024); - elementWiseMulKernel<<>>(reinterpret_cast(output), - reinterpret_cast(input), - inter_size / 4); + elementWiseMulKernel<<>>( + reinterpret_cast(output), reinterpret_cast(input), inter_size / 4); } else if (inter_size & 1 == 0) { using vec_type = typename T2::Type; int const threads = std::min(inter_size / 2, 1024); - elementWiseMulKernel<<>>(reinterpret_cast(output), - reinterpret_cast(input), - inter_size / 2); + elementWiseMulKernel<<>>( + reinterpret_cast(output), reinterpret_cast(input), inter_size / 2); } else { int const threads = std::min(inter_size, 1024); elementWiseMulKernel<<>>(output, input, inter_size); @@ -725,8 +718,7 @@ void CutlassMoeFCRunner::run_moe_fc( const T* fc3_scales, const T* fc3_expert_biases, const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, - T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, - cudaStream_t stream) { + T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) { static constexpr bool 
scales_required = std::is_same::value || std::is_same::value; @@ -750,8 +742,8 @@ void CutlassMoeFCRunner::run_moe_fc( source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream); const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows))); - sorter_.run((void*)fc1_result_, sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_, - permuted_rows_, k * num_rows, stream); + sorter_.run(reinterpret_cast(fc1_result_), sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, + source_rows_, permuted_rows_, k * num_rows, stream); initialize_moe_routing_kernelLauncher(input_activations, permuted_data_, permuted_rows_, expanded_source_row_to_expanded_dest_row, num_rows, active_rows, hidden_size, k, @@ -765,22 +757,10 @@ void CutlassMoeFCRunner::run_moe_fc( dispatch_activations(total_rows_before_expert_, num_experts, local_num_experts, local_experts_start_index, stream); } - // expanded_active_expert_rows is not used - if (fc1_expert_biases != nullptr) { - moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size, - fc1_expert_weights, fc1_scales, fc1_expert_biases, - fc1_result_ + total_past_rows_ * inter_size, - total_rows_before_expert_ + local_experts_start_index, - expanded_active_expert_rows, inter_size, hidden_size, - local_num_experts, fc1_activation_type, stream); - } else { - moe_gemm_runner_.moe_gemm_act(permuted_data_ + total_past_rows_ * hidden_size, - fc1_expert_weights, fc1_scales, - fc1_result_ + total_past_rows_ * inter_size, - total_rows_before_expert_ + local_experts_start_index, - expanded_active_expert_rows, inter_size, hidden_size, - local_num_experts, fc1_activation_type, stream); - } + moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, + fc1_expert_biases, fc1_result_ + total_past_rows_ * inter_size, + total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, + inter_size, hidden_size, local_num_experts, fc1_activation_type, stream); if (has_fc3_) { if (scales_required) { @@ -795,24 +775,31 @@ void CutlassMoeFCRunner::run_moe_fc( if (fc3_expert_weights == nullptr) { ORT_THROW("[FT Error][Run MoE FC] FC3 weights are null"); } - moe_gemm_runner_.moe_gemm(permuted_data_ + total_past_rows_ * hidden_size, - fc3_expert_weights, fc3_scales, fc3_expert_biases, - fc3_result_ + total_past_rows_ * inter_size, - total_rows_before_expert_ + local_experts_start_index, - expanded_active_expert_rows, inter_size, hidden_size, - local_num_experts, stream); + moe_gemm_runner_.moe_gemm(permuted_data_ + total_past_rows_ * hidden_size, fc3_expert_weights, fc3_scales, + fc3_expert_biases, fc3_result_ + total_past_rows_ * inter_size, + total_rows_before_expert_ + local_experts_start_index, expanded_active_expert_rows, + inter_size, hidden_size, local_num_experts, stream); elementWiseMul(fc1_result_ + total_past_rows_ * inter_size, fc3_result_ + total_past_rows_ * inter_size, static_cast(inter_size), static_cast(total_covered_rows_), stream); } - moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size, - fc2_expert_weights, fc2_scales, nullptr, + moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size, fc2_expert_weights, fc2_scales, nullptr, fc2_result + total_past_rows_ * hidden_size, - total_rows_before_expert_ + local_experts_start_index, - expanded_active_expert_rows, hidden_size, inter_size, local_num_experts, stream); + total_rows_before_expert_ + 
local_experts_start_index, expanded_active_expert_rows, + hidden_size, inter_size, local_num_experts, stream); } +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 +template +void CutlassMoeFCRunner::run_moe_fc(const T*, const T*, const WeightType*, const T*, const T*, + ActivationType, const WeightType*, const T*, const T*, + const WeightType*, const T*, int, const int, const int, int, + int, int, int k, char*, T*, T*, int*, int*, cudaStream_t) { + // MoE gemm only supports Volta+ architectures + ORT_THROW("[FT Error][Run MoE FC] MoE gemm only supports Volta+ architectures"); +} +#else template void CutlassMoeFCRunner::run_moe_fc( const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales, @@ -824,9 +811,9 @@ void CutlassMoeFCRunner::run_moe_fc( run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type, fc3_expert_weights, fc3_scales, fc3_expert_biases, fc2_expert_weights, fc2_scales, num_rows, hidden_size, inter_size, num_experts, local_num_experts, local_experts_start_index, k, workspace_ptr, fc2_result, - nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, - stream); + nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row, expert_for_source_row, stream); } +#endif template void CutlassMoeFCRunner::compute_total_rows_before_expert(const int* sorted_indices, @@ -842,8 +829,8 @@ void CutlassMoeFCRunner::compute_total_rows_before_expert } template -void CutlassMoeFCRunner::dispatch_activations(int64_t* total_rows_before_expert, - int num_experts, int local_num_experts, +void CutlassMoeFCRunner::dispatch_activations(int64_t* total_rows_before_expert, int num_experts, + int local_num_experts, int local_experts_start_index, cudaStream_t stream) { total_rows_before_expert_host_.resize(num_experts); @@ -857,16 +844,15 @@ void CutlassMoeFCRunner::dispatch_activations(int64_t* to cudaEventCreateWithFlags(©_event, cudaEventDisableTiming); cudaEventRecord(copy_event, stream); - dispatch_activations_kernel<<>>(total_rows_before_expert, num_experts, - local_num_experts, local_experts_start_index); + dispatch_activations_kernel<<>>(total_rows_before_expert, num_experts, local_num_experts, + local_experts_start_index); get_total_rows_info(local_experts_start_index, local_num_experts, total_past_rows_, total_covered_rows_); } template void CutlassMoeFCRunner::get_total_rows_info(int64_t experts_start_index, - int64_t local_num_experts, - int64_t& total_past_rows, + int64_t local_num_experts, int64_t& total_past_rows, int64_t& total_covered_rows) { int64_t experts_end_index = experts_start_index + local_num_experts - 1; total_past_rows = 0; @@ -923,8 +909,8 @@ __global__ void initialize_moe_routing_kernel(const T* unpermuted_input, T* perm template void initialize_moe_routing_kernelLauncher(const T* unpermuted_input, T* permuted_output, const int* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, int num_rows, - int active_rows, int cols, int k, cudaStream_t stream) { + int* expanded_source_row_to_expanded_dest_row, int num_rows, int active_rows, + int cols, int k, cudaStream_t stream) { const int blocks = num_rows * k; const int threads = std::min(cols, 1024); initialize_moe_routing_kernel @@ -980,8 +966,8 @@ __global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows, T* const int expert_idx = expert_for_source_row[k_offset]; const T* bias_ptr = bias ? 
bias + expert_idx * cols : nullptr; - thread_output = thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + - (bias_ptr ? bias_ptr[tid] : T(0))); + thread_output = + thread_output + row_scale * (expanded_permuted_rows_row_ptr[tid] + (bias_ptr ? bias_ptr[tid] : T(0))); } reduced_row_ptr[tid] = thread_output; } @@ -991,8 +977,8 @@ __global__ void finalize_moe_routing_kernel(const T* expanded_permuted_rows, T* template void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int num_rows, int cols, - int k, cudaStream_t stream) { + const int* expert_for_source_row, int num_rows, int cols, int k, + cudaStream_t stream) { const int blocks = num_rows; const int threads = std::min(cols, 1024); finalize_moe_routing_kernel<<>>( @@ -1004,8 +990,8 @@ template void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int num_rows, int cols, - int k, cudaStream_t stream) { + const int* expert_for_source_row, int num_rows, int cols, int k, + cudaStream_t stream) { const int blocks = num_rows; const int threads = std::min(cols, 1024); finalize_moe_routing_kernel @@ -1017,8 +1003,8 @@ template void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, const T* skip_2, const T* bias, const T* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, int num_rows, int cols, - int k, cudaStream_t stream) { + const int* expert_for_source_row, int num_rows, int cols, int k, + cudaStream_t stream) { const int blocks = num_rows; const int threads = std::min(cols, 1024); if (skip_2 == nullptr) { @@ -1033,20 +1019,21 @@ void finalize_moe_routing_kernelLauncher(const T* expanded_permuted_rows, T* red } // ========================= TopK Softmax specializations =========================== -template void topk_gating_softmax_kernelLauncher(const float*, const bool*, float*, float*, int*, int*, int, - int, int, bool, cudaStream_t); -template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, - int, int, bool, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const float*, const bool*, float*, float*, int*, int*, int, int, int, + bool, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, int, int, + bool, cudaStream_t); // ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner; // ===================== Specializations for init routing ========================= -template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, - int, int, cudaStream_t); -template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, - int, int, cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, int, int, + cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, int, int, + cudaStream_t); // ==================== 
Specializations for final routing =================================== template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*, @@ -1054,15 +1041,12 @@ template void finalize_moe_routing_kernelLauncher(const float*, float*, const fl template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const int*, const int*, int, int, int, cudaStream_t); template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, - const int*, const int*, int, int, int, - cudaStream_t); + const int*, const int*, int, int, int, cudaStream_t); template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, const int*, const int*, int, int, int, cudaStream_t); template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const float*, - const float*, const int*, const int*, int, int, int, - cudaStream_t); + const float*, const int*, const int*, int, int, int, cudaStream_t); template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, - const half*, const int*, const int*, int, int, int, - cudaStream_t); + const half*, const int*, const int*, int, int, int, cudaStream_t); } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index b13aab959fc48..dbd783c0cb11c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -13,26 +13,16 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - MoE, \ - kMSDomain, \ - 1, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(0, 0) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - MoE); +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + MoE, kMSDomain, 1, T, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()).MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), MoE); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) -using namespace ONNX_NAMESPACE; - template -MoE::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { -} +MoE::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {} template Status MoE::ComputeInternal(OpKernelContext* context) const { @@ -46,9 +36,10 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { const Tensor* fc3_experts_bias_optional = context->Input(7); MoEParameters moe_params; - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional, - fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional)); + MoEQuantType quant_type = MoEQuantType::None; + ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, + fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, + fc3_experts_weights_optional, fc3_experts_bias_optional)); typedef typename ToCudaType::MappedType CudaT; auto stream = context->GetComputeStream(); @@ -56,14 +47,12 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); const int sm = device_prop.major * 10 + device_prop.minor; - ort_fastertransformer::CutlassMoeFCRunner 
moe_runner(sm, - fc3_experts_weights_optional != nullptr, + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, normalize_routing_weights_); - size_t ws_size = - moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), - static_cast(moe_params.inter_size), - static_cast(moe_params.num_experts), static_cast(k_)); + size_t ws_size = moe_runner.getWorkspaceSize( + static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), static_cast(k_)); size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT); size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT); size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); @@ -83,36 +72,28 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); const CudaT* fc_scales_ptr = nullptr; - moe_runner.run_moe_fc(reinterpret_cast(input->template Data()), - reinterpret_cast(router_probs->template Data()), - reinterpret_cast(fc1_experts_weights->DataRaw()), - fc_scales_ptr, - fc1_experts_bias_optional == nullptr - ? nullptr - : reinterpret_cast(fc1_experts_bias_optional->template Data()), - activation_type_, - fc3_experts_weights_optional == nullptr - ? nullptr - : reinterpret_cast(fc3_experts_weights_optional->DataRaw()), - fc_scales_ptr, - fc3_experts_bias_optional == nullptr - ? nullptr - : reinterpret_cast(fc3_experts_bias_optional->template Data()), - reinterpret_cast(fc2_experts_weights->DataRaw()), - fc_scales_ptr, - static_cast(moe_params.num_rows), - static_cast(moe_params.hidden_size), - static_cast(moe_params.inter_size), - static_cast(moe_params.num_experts), - static_cast(moe_params.local_num_experts), - 0 /*local_experts_start_index_ used in sharded MoE*/, - static_cast(k_), - reinterpret_cast(work_space.get()), - reinterpret_cast(fc2_output.get()), - reinterpret_cast(expert_scales.get()), - reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), - reinterpret_cast(expert_for_source_row.get()), - Stream(context)); + moe_runner.run_moe_fc( + reinterpret_cast(input->template Data()), + reinterpret_cast(router_probs->template Data()), + reinterpret_cast(fc1_experts_weights->DataRaw()), fc_scales_ptr, + fc1_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc1_experts_bias_optional->template Data()), + activation_type_, + fc3_experts_weights_optional == nullptr ? nullptr + : reinterpret_cast(fc3_experts_weights_optional->DataRaw()), + fc_scales_ptr, + fc3_experts_bias_optional == nullptr + ? 
nullptr + : reinterpret_cast(fc3_experts_bias_optional->template Data()), + reinterpret_cast(fc2_experts_weights->DataRaw()), fc_scales_ptr, + static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), + static_cast(moe_params.local_num_experts), 0 /*local_experts_start_index_ used in sharded MoE*/, + static_cast(k_), reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), Stream(context)); Tensor* output = context->Output(0, input->Shape()); @@ -124,8 +105,7 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { reinterpret_cast(expert_scales.get()), reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), reinterpret_cast(expert_for_source_row.get()), static_cast(moe_params.num_rows), - static_cast(moe_params.hidden_size), - static_cast(k_), Stream(context)); + static_cast(moe_params.hidden_size), static_cast(k_), Stream(context)); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index 84a5e8c7c120d..4ecf1a6206643 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -18,6 +18,11 @@ enum class MoEParallelType { EPAndTP = 3, }; +enum class MoEQuantType { + None = 0, + UINT4 = 1, +}; + struct MoEParameters { MoEParameters() {} explicit MoEParameters(int64_t tensor_shards) : tensor_shards(tensor_shards) {} @@ -33,14 +38,10 @@ struct MoEParameters { class MoEBase { public: - Status CheckInputs(MoEParameters& parameters, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, + Status CheckInputs(MoEParameters& parameters, MoEQuantType& quant_type, const Tensor* input, + const Tensor* router_probs, const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias_optional, const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias_optional, const Tensor* fc3_experts_weights_optional, const Tensor* fc3_experts_bias_optional) const { const auto& input_dims = input->Shape().GetDims(); const auto& router_probs_dims = router_probs->Shape().GetDims(); @@ -51,7 +52,7 @@ class MoEBase { int64_t hidden_size = input_dims[input_dims.size() - 1]; int64_t local_num_experts = fc1_experts_weights_dims[0]; int64_t num_experts = router_probs_dims[1]; - int64_t inter_size = fc1_experts_weights_dims[2]; + int64_t inter_size = fc2_experts_weights_dims[1]; if (fc1_experts_weights_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ", @@ -69,20 +70,21 @@ class MoEBase { if (fc2_experts_weights_dims[1] != inter_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims[1] must be equal to inter_size, got ", - fc2_experts_weights_dims[1], - " and ", inter_size); + fc2_experts_weights_dims[1], " and ", inter_size); } - if (fc1_experts_weights_dims[2] != inter_size) { + + const int64_t coe = quant_type == MoEQuantType::UINT4 ? 
2 : 1; + if (fc1_experts_weights_dims[2] != inter_size / coe) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_weights_dims[2] must be equal to inter_size, got ", - fc1_experts_weights_dims[2], - " and ", inter_size); + "fc1_experts_weights_dims[2] must be equal to inter_size / coe, got ", + fc1_experts_weights_dims[2], " and ", inter_size / coe); } - if (fc2_experts_weights_dims[2] != hidden_size) { + if (fc2_experts_weights_dims[2] != hidden_size / coe) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", - fc2_experts_weights_dims[2], " and ", hidden_size); + "fc2_experts_weights_dims[2] must be equal to hidden_size / coe, got ", + fc2_experts_weights_dims[2], " and ", hidden_size / coe); } + if (router_probs_dims.size() != 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ", router_probs_dims.size()); @@ -105,25 +107,21 @@ if (fc1_experts_bias_dims[0] != local_num_experts) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims[0] must be equal to local_num_experts, got ", - fc1_experts_bias_dims[0], - " and ", local_num_experts); + fc1_experts_bias_dims[0], " and ", local_num_experts); } if (fc2_experts_bias_dims[0] != num_experts) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_bias_dims[0] must be equal to num_experts, got ", - fc2_experts_bias_dims[0], + "fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0], " and ", num_experts); } if (fc1_experts_bias_dims[1] != inter_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_bias_dims[1] must be equal to inter_size, got ", - fc1_experts_bias_dims[1], + "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1], " and ", inter_size); } if (fc2_experts_bias_dims[1] != hidden_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", - fc2_experts_bias_dims[1], + "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", fc2_experts_bias_dims[1], " and ", hidden_size); } } @@ -137,10 +135,9 @@ class MoEBase { if (fc3_experts_bias_optional != nullptr && fc1_experts_bias_optional != nullptr && fc3_experts_bias_optional->Shape().GetDims() != fc1_experts_bias_optional->Shape().GetDims()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc3_experts_bias_dims must be equal to fc1_experts_bias_dims, got ", - fc3_experts_bias_optional->Shape().GetDims(), " and ", - fc1_experts_bias_optional->Shape().GetDims()); + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, "fc3_experts_bias_dims must be equal to fc1_experts_bias_dims, got ", + fc3_experts_bias_optional->Shape().GetDims(), " and ", fc1_experts_bias_optional->Shape().GetDims()); } parameters.num_rows = num_rows; @@ -162,8 +159,47 @@ class MoEBase { } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_experts must be greater than or equal to local_num_experts, got ", - num_experts, " and ", local_num_experts); + "num_experts must be greater than or equal to local_num_experts, got ", num_experts, + " and ", local_num_experts); + } + + return Status::OK(); + } + + Status CheckInputScales(const Tensor* fc1_experts_scales, const Tensor* fc2_experts_scales, + const Tensor* fc3_experts_scales, int64_t num_experts, int64_t hidden_size, + int64_t inter_size) const { + const auto& fc1_experts_scales_dims = fc1_experts_scales->Shape().GetDims(); + const auto& fc2_experts_scales_dims = fc2_experts_scales->Shape().GetDims(); + + if (fc1_experts_scales_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"fc1_experts_scales must be 2D, got ", + fc1_experts_scales->Shape().GetDims().size()); + } + if (fc1_experts_scales_dims[0] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ", + fc1_experts_scales_dims[0], " and ", num_experts); + } + if (fc1_experts_scales_dims[1] != inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to inter_size, got ", + fc1_experts_scales_dims[1], " and ", inter_size); + } + if (fc2_experts_scales_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ", + fc2_experts_scales->Shape().GetDims().size()); + } + if (fc2_experts_scales_dims[0] != num_experts) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[0] must be equal to num_experts, got ", + fc2_experts_scales_dims[0], " and ", num_experts); + } + if (fc2_experts_scales_dims[1] != hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales[1] must be equal to hidden_size, got ", + fc2_experts_scales_dims[1], " and ", hidden_size); + } + if (fc3_experts_scales != nullptr && fc1_experts_scales_dims != fc3_experts_scales->Shape().GetDims()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "fc3_experts_scales must be equal to fc1_experts_scales, got ", + fc3_experts_scales->Shape().GetDims(), " and ", fc1_experts_scales_dims); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc new file mode 100644 index 0000000000000..7bb0945615d37 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/quantization/moe_quantization.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL() \ + ONNX_OPERATOR_KERNEL_EX(QMoE, kMSDomain, 1, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(0, 0) \ + .TypeConstraint("T", BuildKernelDefConstraints()) \ + .TypeConstraint("T1", BuildKernelDefConstraints()), \ + QMoE); + +REGISTER_KERNEL() + +namespace { +template +struct ToCudaTypeWrapper : public ToCudaType {}; + +template <> +struct ToCudaTypeWrapper { + using MappedType = uint8_t; +}; + +template <> +struct ToCudaTypeWrapper { + using MappedType = cutlass::uint4b_t; +}; +} // anonymous namespace + +QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {} + +Status QMoE::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* router_probs = context->Input(1); + const Tensor* fc1_experts_weights = context->Input(2); + const Tensor* fc1_scales = context->Input(3); + const Tensor* fc1_experts_bias_optional = context->Input(4); + const Tensor* fc2_experts_weights = context->Input(5); + const Tensor* fc2_scales = context->Input(6); + const Tensor* fc2_experts_bias_optional = context->Input(7); + const Tensor* fc3_experts_weights_optional = context->Input(8); + const Tensor* fc3_scales_optional = context->Input(9); + const Tensor* fc3_experts_bias_optional = context->Input(10); + + MoEParameters moe_params; + MoEQuantType quant_type = MoEQuantType::UINT4; + ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, + fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, + fc3_experts_weights_optional, fc3_experts_bias_optional)); + ORT_RETURN_IF_ERROR(CheckInputScales(fc1_scales, fc2_scales, fc3_scales_optional, moe_params.num_experts, + moe_params.hidden_size, moe_params.inter_size)); + + // Support int4 only at the moment. We can add uint8 if needed. 
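// Not part of this patch: a hypothetical reference for the uint4x2 convention
// assumed by use_quint4x2 below. Each uint8 weight byte packs two int4 values
// (low nibble = even column, stored offset-binary), and a per-output-column
// scale dequantizes both; the CUDA kernels do the real unpacking in
// cutlass_extensions' interleaved_numeric_conversion.h. Names are illustrative.
auto example_dequant_uint4x2 = [](uint8_t packed_pair, float scale, float out[2]) {
  out[0] = static_cast<float>(static_cast<int>(packed_pair & 0x0F) - 8) * scale;  // even column
  out[1] = static_cast<float>(static_cast<int>(packed_pair >> 4) - 8) * scale;    // odd column
};
(void)example_dequant_uint4x2;  // illustration only, never called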
+ static constexpr bool use_quint4x2 = true; + using T = MLFloat16; + using CudaT = typename ToCudaType::MappedType; + using CudaWeightT = typename ToCudaTypeWrapper::MappedType; + + auto stream = context->GetComputeStream(); + + auto& device_prop = GetDeviceProp(); + const int sm = device_prop.major * 10 + device_prop.minor; + + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, + normalize_routing_weights_); + + size_t ws_size = moe_runner.getWorkspaceSize( + static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), static_cast(k_)); + size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT); + size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT); + size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); + size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, stream); + IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, stream); + IAllocatorUniquePtr expert_scales = + IAllocator::MakeUniquePtr(allocator, expert_scales_size, false, stream); + IAllocatorUniquePtr expanded_source_row_to_expanded_dest_row = + IAllocator::MakeUniquePtr(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream); + IAllocatorUniquePtr expert_for_source_row = + IAllocator::MakeUniquePtr(allocator, expert_for_source_row_size, false, stream); + + moe_runner.run_moe_fc( + reinterpret_cast(input->template Data()), + reinterpret_cast(router_probs->template Data()), + reinterpret_cast(fc1_experts_weights->DataRaw()), + fc1_scales == nullptr ? nullptr : reinterpret_cast(fc1_scales->template Data()), + fc1_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc1_experts_bias_optional->template Data()), + activation_type_, + fc3_experts_weights_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_weights_optional->DataRaw()), + fc3_scales_optional == nullptr ? nullptr + : reinterpret_cast(fc3_scales_optional->template Data()), + fc3_experts_bias_optional == nullptr + ? nullptr + : reinterpret_cast(fc3_experts_bias_optional->template Data()), + reinterpret_cast(fc2_experts_weights->DataRaw()), + fc2_scales == nullptr ? nullptr : reinterpret_cast(fc2_scales->template Data()), + static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), + static_cast(moe_params.local_num_experts), 0 /*local_experts_start_index_ used in sharded MoE*/, + static_cast(k_), reinterpret_cast(work_space.get()), reinterpret_cast(fc2_output.get()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), Stream(context)); + + Tensor* output = context->Output(0, input->Shape()); + + ort_fastertransformer::finalize_moe_routing_kernelLauncher( + reinterpret_cast(fc2_output.get()), reinterpret_cast(output->template MutableData()), + fc2_experts_bias_optional == nullptr + ? 
nullptr + : reinterpret_cast(fc2_experts_bias_optional->template Data()), + reinterpret_cast(expert_scales.get()), + reinterpret_cast(expanded_source_row_to_expanded_dest_row.get()), + reinterpret_cast(expert_for_source_row.get()), static_cast(moe_params.num_rows), + static_cast(moe_params.hidden_size), static_cast(k_), Stream(context)); + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h new file mode 100644 index 0000000000000..7b68d2d082de8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h" +#include "contrib_ops/cuda/moe/moe_base.h" +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +class QMoE final : public CudaKernel, public MoEBase { + public: + explicit QMoE(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* ctx) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 82cc16acad582..f4d990540ed51 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1404,6 +1404,64 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +ONNX_MS_OPERATOR_SET_SCHEMA( + QMoE, 1, + OpSchema() + .SetDoc("Int4 MoE") + .Attr("activation_type", + "Activation function to use. Choose from relu, gelu, silu and identity. 
Default is relu", + AttributeProto::STRING, + std::string("relu")) + .Attr("k", + "Number of top experts to select from expert pool", + AttributeProto::INT, + static_cast(1)) + .Attr("normalize_routing_weights", + "Whether to normalize routing weights", + AttributeProto::INT, + static_cast(0)) + .Input(0, + "input", + "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " + "(batch_size, sequence_length, hidden_size)", + "T") + .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") + .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size / 2)", "T1") + .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size)", "T") + .Input(4, + "fc1_experts_bias", + "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(5, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size / 2)", "T1") + .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T") + .Input(7, + "fc2_experts_bias", + "2D optional input tensor with shape (num_experts, hidden_size)", + "T", + OpSchema::Optional) + .Input(8, + "fc3_experts_weights", + "3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)", + "T1", + OpSchema::Optional) + .Input(9, + "fc3_scales", + "2D optional input tensor with shape (num_experts, inter_size)", + "T", + OpSchema::Optional) + .Input(10, + "fc3_experts_bias", + "2D optional input tensor with shape (num_experts, inter_size)", + "T", + OpSchema::Optional) + .Output(0, + "output", + "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " + "(batch_size, sequence_length, hidden_size)", + "T") + .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") + .TypeConstraint("T1", {"tensor(uint8)"}, "Constrain weights type to uint8 tensors.") + .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); + ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, OpSchema() .Input(0, "X", "input", "T") diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 5eef1b33a24dd..ef86352080ff5 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -84,6 +84,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4); #endif class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MaxpoolWithMask); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MoE); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); @@ -191,6 +192,7 @@ class OpSet_Microsoft_ver1 { #endif fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 263ace25ddfe0..7dbaadd51d14a 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -9,24 +9,14 @@ namespace onnxruntime { namespace test { -static void RunMoETest( - const std::vector& input, - const std::vector& router_probs, - const std::vector& fc1_experts_weights, - const std::vector& fc2_experts_weights, - 
const std::vector<float>& fc3_experts_weights, - const std::vector<float>& fc1_experts_bias, - const std::vector<float>& fc2_experts_bias, - const std::vector<float>& output_data, - int num_rows, - int num_experts, - int hidden_size, - int inter_size, - std::string activation_type, - int normalize_routing_weights = 0, - int top_k = 1, - bool use_float16 = false) { - int min_cuda_architecture = use_float16 ? 530 : 0; +#ifndef ENABLE_TRAINING +static void RunMoETest(const std::vector<float>& input, const std::vector<float>& router_probs, + const std::vector<float>& fc1_experts_weights, const std::vector<float>& fc2_experts_weights, + const std::vector<float>& fc3_experts_weights, const std::vector<float>& fc1_experts_bias, + const std::vector<float>& fc2_experts_bias, const std::vector<float>& output_data, int num_rows, + int num_experts, int hidden_size, int inter_size, std::string activation_type, + int normalize_routing_weights = 0, int top_k = 1, bool use_float16 = false) { + int min_cuda_architecture = use_float16 ? 700 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); if (enable_cuda) { @@ -92,6 +82,52 @@ static void RunMoETest( } } +// TODO(wy): Add python parity tests that can serve as examples. Need cutlass upgrade to build cutlass extensions to +// add weights preprocessor to onnxruntime_pybind_quant.cc +static void RunQMoETest(const std::vector<float>& input, const std::vector<float>& router_probs, + const std::vector<uint8_t>& fc1_experts_weights, + const std::vector<uint8_t>& fc2_experts_weights, + const std::vector<uint8_t>& fc3_experts_weights, const std::vector<float>& fc1_scales, + const std::vector<float>& fc2_scales, const std::vector<float>& fc3_scales, + const std::vector<float>& output_data, int num_rows, int num_experts, int hidden_size, + int inter_size, std::string activation_type, int normalize_routing_weights = 0, int top_k = 1) { + bool enable_cuda = HasCudaEnvironment(700); + if (enable_cuda) { + OpTester tester("QMoE", 1, onnxruntime::kMSDomain); + tester.AddAttribute<int64_t>("k", static_cast<int64_t>(top_k)); + tester.AddAttribute<std::string>("activation_type", activation_type); + tester.AddAttribute<int64_t>("normalize_routing_weights", static_cast<int64_t>(normalize_routing_weights)); + + std::vector<int64_t> input_dims = {num_rows, hidden_size}; + std::vector<int64_t> router_probs_dims = {num_rows, num_experts}; + std::vector<int64_t> fc1_experts_weights_dims = {num_experts, hidden_size, inter_size / 2}; + std::vector<int64_t> fc2_experts_weights_dims = {num_experts, inter_size, hidden_size / 2}; + std::vector<int64_t> fc3_experts_weights_dims = fc1_experts_weights_dims; + std::vector<int64_t> fc1_scales_dims = {num_experts, inter_size}; + std::vector<int64_t> fc2_scales_dims = {num_experts, hidden_size}; + std::vector<int64_t> fc3_scales_dims = fc1_scales_dims; + std::vector<int64_t> output_dims = {num_rows, hidden_size}; + + tester.AddInput<MLFloat16>("input", input_dims, ToFloat16(input)); + tester.AddInput<MLFloat16>("router_probs", router_probs_dims, ToFloat16(router_probs)); + + tester.AddInput<uint8_t>("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + tester.AddInput<MLFloat16>("fc1_scales", fc1_scales_dims, ToFloat16(fc1_scales)); + tester.AddOptionalInputEdge<MLFloat16>(); // fc1_experts_bias + tester.AddInput<uint8_t>("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + tester.AddInput<MLFloat16>("fc2_scales", fc2_scales_dims, ToFloat16(fc2_scales)); + tester.AddOptionalInputEdge<MLFloat16>(); // fc2_experts_bias + tester.AddInput<uint8_t>("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights); + tester.AddInput<MLFloat16>("fc3_scales", fc3_scales_dims, ToFloat16(fc3_scales)); + tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data)); + tester.SetOutputTolerance(0.005f); + + std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
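// Not part of this patch: the raw uint8 weight vectors passed to AddInput above
// are expected to arrive already int4-packed (two weights per byte). A packing
// step in the spirit of the hypothetical QuantizeExpertInt4 sketch shown earlier
// would produce them from float weights, one expert at a time, e.g.:
//   std::vector<uint8_t> packed;
//   std::vector<float> scales;
//   QuantizeExpertInt4(expert_weights, hidden_size, inter_size, packed, scales);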
execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MoETest, MoETest_Gelu) { int num_rows = 4; int num_experts = 4; @@ -107,135 +143,145 @@ TEST(MoETest, MoETest_Gelu) { -0.84837115f, 0.100507565f, -0.10548311f, 0.40957215f, 1.0159845f, 0.26919764f, 0.021741152f, -0.34184334f, -0.71324956f, 0.29018253f, -0.18227568f, 0.31496462f, -0.48426327f, -1.006643f, -0.100081146f, -0.07692295f}; const std::vector fc1_experts_weights = { - 0.14731085f, 0.52229995f, 0.14753294f, 0.22475791f, 0.20864725f, 0.6708725f, 0.20204341f, 0.4890914f, - 0.52103406f, 0.8223115f, 0.122039974f, 0.15674388f, 0.20966923f, 0.8499667f, 0.3202675f, 0.92174435f, - 0.6808038f, 0.563313f, 0.496278f, 0.40115923f, 0.5627332f, 0.38582766f, 0.49648678f, 0.5637965f, - 0.10889745f, 0.23793429f, 0.90374637f, 0.09422666f, 0.4640969f, 0.99461937f, 0.6806185f, 0.5141565f, - 0.066695035f, 0.74768895f, 0.14385962f, 0.35806787f, 0.33224183f, 0.4259563f, 0.50546914f, 0.91240376f, - 0.5624194f, 0.9478464f, 0.8058562f, 0.18389302f, 0.72425205f, 0.14655197f, 0.28808743f, 0.64706135f, - 0.66509604f, 0.875114f, 0.33904207f, 0.50080043f, 0.7574118f, 0.016453922f, 0.8614903f, 0.08653879f, - 0.50689125f, 0.41499162f, 0.23666352f, 0.5660855f, 0.91345936f, 0.35384023f, 0.20315295f, 0.31508058f, - 0.0044258237f, 0.725697f, 0.25986814f, 0.16632986f, 0.21194929f, 0.787478f, 0.76478684f, 0.8837609f, - 0.68136156f, 0.33302015f, 0.36027592f, 0.647715f, 0.91101736f, 0.6359461f, 0.26342732f, 0.2649613f, - 0.02726549f, 0.608024f, 0.21940875f, 0.054212093f, 0.93843824f, 0.1752944f, 0.44311923f, 0.64324677f, - 0.51592916f, 0.16355914f, 0.09583914f, 0.8985412f, 0.58141935f, 0.91481227f, 0.3323797f, 0.6472777f, - 0.3856619f, 0.47776443f, 0.1954779f, 0.66910046f, 0.65808296f, 0.4896857f, 0.38754892f, 0.1917851f, - 0.8457724f, 0.12778795f, 0.70483273f, 0.33187324f, 0.258766f, 0.58982253f, 0.24027151f, 0.6152024f, - 0.5981904f, 0.12875527f, 0.5832493f, 0.7129646f, 0.6979155f, 0.43706065f, 0.09010619f, 0.42292297f, - 0.67365384f, 0.31756145f, 0.68979055f, 0.8329813f, 0.2389242f, 0.5049309f, 0.7067495f, 0.5391889f, - 0.54176575f, 0.5624327f, 0.10692614f, 0.5392941f, 0.8462349f, 0.9505569f, 0.79387546f, 0.5670015f, - 0.7335071f, 0.25676018f, 0.08565581f, 0.07003945f, 0.99880487f, 0.8173947f, 0.15438312f, 0.6956213f, - 0.8775838f, 0.9998074f, 0.93719745f, 0.8873769f, 0.38537037f, 0.32452917f, 0.9105244f, 0.7801898f, - 0.19911051f, 0.9495086f, 0.7415793f, 0.77256775f, 0.18661183f, 0.6434499f, 0.32471877f, 0.8906783f, - 0.4100297f, 0.69465625f, 0.5888109f, 0.7127341f, 0.33008623f, 0.7437857f, 0.15076452f, 0.6129275f, - 0.16170406f, 0.006731212f, 0.09847212f, 0.89473504f, 0.7705178f, 0.96910787f, 0.9005606f, 0.053477287f, - 0.15878445f, 0.4192087f, 0.17528385f, 0.84719825f, 0.121996105f, 0.25604928f, 0.016954303f, 0.21612722f, - 0.91123873f, 0.90938f, 0.85791886f, 0.88606364f, 0.94459325f, 0.3719685f, 0.72000104f, 0.9454652f, - 0.6654094f, 0.9998382f, 0.75933146f, 0.81082416f, 0.32500392f, 0.73991376f, 0.5574533f, 0.38059133f, - 0.21814507f, 0.21944171f, 0.11525959f, 0.83566517f, 0.8554656f, 0.44309366f, 0.210657f, 0.88645273f, - 0.81974447f, 0.537167f, 0.26393235f, 0.9595239f, 0.70447034f, 0.12042731f, 0.97854143f, 0.8796869f, - 0.31775457f, 0.78107727f, 0.21590549f, 0.42164284f, 0.9245506f, 0.52065957f, 0.14639091f, 0.33288354f, - 0.36427742f, 0.4035356f, 0.5478503f, 0.9624148f, 0.5267702f, 0.19128f, 0.52562714f, 0.7397436f, - 0.7480201f, 0.04303074f, 
0.41052878f, 0.12842774f, 0.2866572f, 0.6801467f, 0.1449349f, 0.68586344f, - 0.92438906f, 0.5327942f, 0.16675615f, 0.32085752f, 0.60918206f, 0.11884099f, 0.74840516f, 0.04606521f, - 0.01935333f, 0.014169693f, 0.39856833f, 0.83621645f, 0.026760519f, 0.91559356f, 0.29998857f, 0.64644206f, - 0.52280146f, 0.049140453f, 0.9146645f, 0.7692217f, 0.99699783f, 0.7526061f, 0.1699655f, 0.9172919f, - 0.5268722f, 0.73710823f, 0.09908545f, 0.35618675f, 0.009061217f, 0.30525374f, 0.6078656f, 0.10741913f, - 0.6593821f, 0.7684034f, 0.56965464f, 0.16545832f, 0.11234015f, 0.3457417f, 0.7194791f, 0.9931982f, - 0.7875145f, 0.44369537f, 0.6753082f, 0.009468555f, 0.07294935f, 0.73330396f, 0.2167924f, 0.74054784f, - 0.14703393f, 0.25234455f, 0.08815551f, 0.76092035f, 0.44905245f, 0.88480055f, 0.8094361f, 0.7766713f, - 0.51607805f, 0.345411f, 0.39128417f, 0.5664503f, 0.74785477f, 0.14970505f, 0.91963893f, 0.44563496f, - 0.08102721f, 0.22947109f, 0.94240886f, 0.9572636f, 0.036860168f, 0.85264915f, 0.7505796f, 0.79595923f, - 0.9232646f, 0.23052484f, 0.6578879f, 0.7046166f, 0.35225332f, 0.66732657f, 0.3561433f, 0.80913067f, - 0.3612727f, 0.31360215f, 0.6258745f, 0.6773468f, 0.25571418f, 0.54419917f, 0.78976786f, 0.45025164f, - 0.65216696f, 0.3794065f, 0.6752498f, 0.1378029f, 0.2059856f, 0.24620473f, 0.95950544f, 0.36545795f, - 0.49863482f, 0.25775224f, 0.99914503f, 0.9883351f, 0.122906685f, 0.09466505f, 0.12100351f, 0.49758863f, - 0.37254804f, 0.17272717f, 0.32066393f, 0.59446543f, 0.23875463f, 0.61079127f, 0.38534206f, 0.25771832f, - 0.56869274f, 0.9111291f, 0.16196036f, 0.5232172f, 0.31561613f, 0.99065316f, 0.025618374f, 0.0206694f, - 0.9926925f, 0.18365502f, 0.5958617f, 0.45684695f, 0.3946715f, 0.3883261f, 0.8177203f, 0.5238985f, - 0.013192713f, 0.20481992f, 0.32954985f, 0.7516082f, 0.17643315f, 0.9714598f, 0.38863534f, 0.410219f, - 0.891779f, 0.75130385f, 0.92406017f, 0.7892222f, 0.34832305f, 0.1682638f, 0.46279848f, 0.9138188f, - 0.3321901f, 0.036315024f, 0.7049642f, 0.9867357f, 0.3576584f, 0.08598822f, 0.046470165f, 0.6252997f, - 0.46214014f, 0.24750638f, 0.60106593f, 0.6898794f, 0.8976595f, 0.8881911f, 0.42515814f, 0.059116423f, - 0.048188448f, 0.9668448f, 0.7210276f, 0.7179537f, 0.06738949f, 0.96300787f, 0.97367156f, 0.95143014f, - 0.07820749f, 0.3113383f, 0.1561181f, 0.9734828f, 0.28516f, 0.27172273f, 0.76195645f, 0.26870382f, - 0.25373894f, 0.45626426f, 0.45194024f, 0.11051077f, 0.91683406f, 0.27943915f, 0.67735744f, 0.9348918f, - 0.7521582f, 0.57078993f, 0.9254285f, 0.5672131f, 0.2686717f, 0.97299975f, 0.61834025f, 0.012159586f, - 0.3576542f, 0.15941626f, 0.9383765f, 0.41742706f, 0.044237554f, 0.46856833f, 0.81400645f, 0.6299002f, - 0.6581022f, 0.5464366f, 0.68640935f, 0.378174f, 0.3010999f, 0.032645762f, 0.12333155f, 0.71670127f, - 0.20394331f, 0.57173324f, 0.6595957f, 0.53540194f, 0.17582512f, 0.9781642f, 0.20925027f, 0.9112503f, - 0.10224587f, 0.37972575f, 0.7719844f, 0.29570967f, 0.9200215f, 0.15592176f, 0.080114245f, 0.27454042f, - 0.5808252f, 0.96037793f, 0.26129955f, 0.6788141f, 0.37464648f, 0.39156884f, 0.8676517f, 0.112507045f, - 0.55310667f, 0.9702046f, 0.4312939f, 0.88821906f, 0.3460216f, 0.9024811f, 0.016334832f, 0.42793816f, - 0.4121768f, 0.6620425f, 0.6961637f, 0.88390845f, 0.425507f, 0.48017246f, 0.8424056f, 0.36471343f, - 0.9383168f, 0.16709393f, 0.44589508f, 0.47314453f, 0.72310495f, 0.84183806f, 0.4207481f, 0.0857597f, - 0.7477461f, 0.6495659f, 0.70084965f, 0.19156617f, 0.8217978f, 0.9735775f, 0.5433857f, 0.032975793f, - 0.85099494f, 0.12927437f, 0.61493605f, 0.5726589f, 0.26598173f, 
0.6740978f, 0.052783668f, 0.61387974f}; + 0.14731085f, 0.6808038f, 0.066695035f, 0.66509604f, 0.0044258237f, 0.02726549f, 0.3856619f, 0.5981904f, + 0.52229995f, 0.563313f, 0.74768895f, 0.875114f, 0.725697f, 0.608024f, 0.47776443f, 0.12875527f, + 0.14753294f, 0.496278f, 0.14385962f, 0.33904207f, 0.25986814f, 0.21940875f, 0.1954779f, 0.5832493f, + 0.22475791f, 0.40115923f, 0.35806787f, 0.50080043f, 0.16632986f, 0.054212093f, 0.66910046f, 0.7129646f, + 0.20864725f, 0.5627332f, 0.33224183f, 0.7574118f, 0.21194929f, 0.93843824f, 0.65808296f, 0.6979155f, + 0.6708725f, 0.38582766f, 0.4259563f, 0.016453922f, 0.787478f, 0.1752944f, 0.4896857f, 0.43706065f, + 0.20204341f, 0.49648678f, 0.50546914f, 0.8614903f, 0.76478684f, 0.44311923f, 0.38754892f, 0.09010619f, + 0.4890914f, 0.5637965f, 0.91240376f, 0.08653879f, 0.8837609f, 0.64324677f, 0.1917851f, 0.42292297f, + 0.52103406f, 0.10889745f, 0.5624194f, 0.50689125f, 0.68136156f, 0.51592916f, 0.8457724f, 0.67365384f, + 0.8223115f, 0.23793429f, 0.9478464f, 0.41499162f, 0.33302015f, 0.16355914f, 0.12778795f, 0.31756145f, + 0.122039974f, 0.90374637f, 0.8058562f, 0.23666352f, 0.36027592f, 0.09583914f, 0.70483273f, 0.68979055f, + 0.15674388f, 0.09422666f, 0.18389302f, 0.5660855f, 0.647715f, 0.8985412f, 0.33187324f, 0.8329813f, + 0.20966923f, 0.4640969f, 0.72425205f, 0.91345936f, 0.91101736f, 0.58141935f, 0.258766f, 0.2389242f, + 0.8499667f, 0.99461937f, 0.14655197f, 0.35384023f, 0.6359461f, 0.91481227f, 0.58982253f, 0.5049309f, + 0.3202675f, 0.6806185f, 0.28808743f, 0.20315295f, 0.26342732f, 0.3323797f, 0.24027151f, 0.7067495f, + 0.92174435f, 0.5141565f, 0.64706135f, 0.31508058f, 0.2649613f, 0.6472777f, 0.6152024f, 0.5391889f, + 0.54176575f, 0.8775838f, 0.4100297f, 0.15878445f, 0.6654094f, 0.81974447f, 0.36427742f, 0.92438906f, + 0.5624327f, 0.9998074f, 0.69465625f, 0.4192087f, 0.9998382f, 0.537167f, 0.4035356f, 0.5327942f, + 0.10692614f, 0.93719745f, 0.5888109f, 0.17528385f, 0.75933146f, 0.26393235f, 0.5478503f, 0.16675615f, + 0.5392941f, 0.8873769f, 0.7127341f, 0.84719825f, 0.81082416f, 0.9595239f, 0.9624148f, 0.32085752f, + 0.8462349f, 0.38537037f, 0.33008623f, 0.121996105f, 0.32500392f, 0.70447034f, 0.5267702f, 0.60918206f, + 0.9505569f, 0.32452917f, 0.7437857f, 0.25604928f, 0.73991376f, 0.12042731f, 0.19128f, 0.11884099f, + 0.79387546f, 0.9105244f, 0.15076452f, 0.016954303f, 0.5574533f, 0.97854143f, 0.52562714f, 0.74840516f, + 0.5670015f, 0.7801898f, 0.6129275f, 0.21612722f, 0.38059133f, 0.8796869f, 0.7397436f, 0.04606521f, + 0.7335071f, 0.19911051f, 0.16170406f, 0.91123873f, 0.21814507f, 0.31775457f, 0.7480201f, 0.01935333f, + 0.25676018f, 0.9495086f, 0.006731212f, 0.90938f, 0.21944171f, 0.78107727f, 0.04303074f, 0.014169693f, + 0.08565581f, 0.7415793f, 0.09847212f, 0.85791886f, 0.11525959f, 0.21590549f, 0.41052878f, 0.39856833f, + 0.07003945f, 0.77256775f, 0.89473504f, 0.88606364f, 0.83566517f, 0.42164284f, 0.12842774f, 0.83621645f, + 0.99880487f, 0.18661183f, 0.7705178f, 0.94459325f, 0.8554656f, 0.9245506f, 0.2866572f, 0.026760519f, + 0.8173947f, 0.6434499f, 0.96910787f, 0.3719685f, 0.44309366f, 0.52065957f, 0.6801467f, 0.91559356f, + 0.15438312f, 0.32471877f, 0.9005606f, 0.72000104f, 0.210657f, 0.14639091f, 0.1449349f, 0.29998857f, + 0.6956213f, 0.8906783f, 0.053477287f, 0.9454652f, 0.88645273f, 0.33288354f, 0.68586344f, 0.64644206f, + 0.52280146f, 0.6593821f, 0.14703393f, 0.08102721f, 0.3612727f, 0.49863482f, 0.56869274f, 0.013192713f, + 0.049140453f, 0.7684034f, 0.25234455f, 0.22947109f, 0.31360215f, 0.25775224f, 0.9111291f, 0.20481992f, + 
0.9146645f, 0.56965464f, 0.08815551f, 0.94240886f, 0.6258745f, 0.99914503f, 0.16196036f, 0.32954985f, + 0.7692217f, 0.16545832f, 0.76092035f, 0.9572636f, 0.6773468f, 0.9883351f, 0.5232172f, 0.7516082f, + 0.99699783f, 0.11234015f, 0.44905245f, 0.036860168f, 0.25571418f, 0.122906685f, 0.31561613f, 0.17643315f, + 0.7526061f, 0.3457417f, 0.88480055f, 0.85264915f, 0.54419917f, 0.09466505f, 0.99065316f, 0.9714598f, + 0.1699655f, 0.7194791f, 0.8094361f, 0.7505796f, 0.78976786f, 0.12100351f, 0.025618374f, 0.38863534f, + 0.9172919f, 0.9931982f, 0.7766713f, 0.79595923f, 0.45025164f, 0.49758863f, 0.0206694f, 0.410219f, + 0.5268722f, 0.7875145f, 0.51607805f, 0.9232646f, 0.65216696f, 0.37254804f, 0.9926925f, 0.891779f, + 0.73710823f, 0.44369537f, 0.345411f, 0.23052484f, 0.3794065f, 0.17272717f, 0.18365502f, 0.75130385f, + 0.09908545f, 0.6753082f, 0.39128417f, 0.6578879f, 0.6752498f, 0.32066393f, 0.5958617f, 0.92406017f, + 0.35618675f, 0.009468555f, 0.5664503f, 0.7046166f, 0.1378029f, 0.59446543f, 0.45684695f, 0.7892222f, + 0.009061217f, 0.07294935f, 0.74785477f, 0.35225332f, 0.2059856f, 0.23875463f, 0.3946715f, 0.34832305f, + 0.30525374f, 0.73330396f, 0.14970505f, 0.66732657f, 0.24620473f, 0.61079127f, 0.3883261f, 0.1682638f, + 0.6078656f, 0.2167924f, 0.91963893f, 0.3561433f, 0.95950544f, 0.38534206f, 0.8177203f, 0.46279848f, + 0.10741913f, 0.74054784f, 0.44563496f, 0.80913067f, 0.36545795f, 0.25771832f, 0.5238985f, 0.9138188f, + 0.3321901f, 0.048188448f, 0.25373894f, 0.3576542f, 0.20394331f, 0.5808252f, 0.4121768f, 0.7477461f, + 0.036315024f, 0.9668448f, 0.45626426f, 0.15941626f, 0.57173324f, 0.96037793f, 0.6620425f, 0.6495659f, + 0.7049642f, 0.7210276f, 0.45194024f, 0.9383765f, 0.6595957f, 0.26129955f, 0.6961637f, 0.70084965f, + 0.9867357f, 0.7179537f, 0.11051077f, 0.41742706f, 0.53540194f, 0.6788141f, 0.88390845f, 0.19156617f, + 0.3576584f, 0.06738949f, 0.91683406f, 0.044237554f, 0.17582512f, 0.37464648f, 0.425507f, 0.8217978f, + 0.08598822f, 0.96300787f, 0.27943915f, 0.46856833f, 0.9781642f, 0.39156884f, 0.48017246f, 0.9735775f, + 0.046470165f, 0.97367156f, 0.67735744f, 0.81400645f, 0.20925027f, 0.8676517f, 0.8424056f, 0.5433857f, + 0.6252997f, 0.95143014f, 0.9348918f, 0.6299002f, 0.9112503f, 0.112507045f, 0.36471343f, 0.032975793f, + 0.46214014f, 0.07820749f, 0.7521582f, 0.6581022f, 0.10224587f, 0.55310667f, 0.9383168f, 0.85099494f, + 0.24750638f, 0.3113383f, 0.57078993f, 0.5464366f, 0.37972575f, 0.9702046f, 0.16709393f, 0.12927437f, + 0.60106593f, 0.1561181f, 0.9254285f, 0.68640935f, 0.7719844f, 0.4312939f, 0.44589508f, 0.61493605f, + 0.6898794f, 0.9734828f, 0.5672131f, 0.378174f, 0.29570967f, 0.88821906f, 0.47314453f, 0.5726589f, + 0.8976595f, 0.28516f, 0.2686717f, 0.3010999f, 0.9200215f, 0.3460216f, 0.72310495f, 0.26598173f, + 0.8881911f, 0.27172273f, 0.97299975f, 0.032645762f, 0.15592176f, 0.9024811f, 0.84183806f, 0.6740978f, + 0.42515814f, 0.76195645f, 0.61834025f, 0.12333155f, 0.080114245f, 0.016334832f, 0.4207481f, 0.052783668f, + 0.059116423f, 0.26870382f, 0.012159586f, 0.71670127f, 0.27454042f, 0.42793816f, 0.0857597f, 0.61387974f}; const std::vector fc2_experts_weights = { - 0.18302453f, 0.44593316f, 0.5643144f, 0.9259722f, 0.26143986f, 0.82031804f, 0.4364831f, 0.2625361f, - 0.06460017f, 0.04124081f, 0.98830533f, 0.37530023f, 0.5249744f, 0.63555616f, 0.8398661f, 0.92673707f, - 0.9055086f, 0.12955844f, 0.4198916f, 0.20413119f, 0.21432412f, 0.6186035f, 0.969324f, 0.099448025f, - 0.80260223f, 0.24076664f, 0.40261286f, 0.89688545f, 0.38691485f, 0.5455279f, 0.15048373f, 0.92562044f, - 
0.43536508f, 0.13430476f, 0.64640516f, 0.14449131f, 0.10324633f, 0.5304596f, 0.8964218f, 0.358508f, - 0.73533344f, 0.9296606f, 0.83163047f, 0.23771948f, 0.44519007f, 0.34265757f, 0.09793854f, 0.5002066f, - 0.87621754f, 0.9212578f, 0.54665035f, 0.6135615f, 0.28353918f, 0.8774212f, 0.29194576f, 0.1526736f, - 0.57699674f, 0.7996927f, 0.04920423f, 0.95198375f, 0.67986554f, 0.14969361f, 0.39229625f, 0.93378997f, - 0.11638266f, 0.3538614f, 0.66399014f, 0.06195748f, 0.7740991f, 0.7602738f, 0.81010276f, 0.18122643f, - 0.9980005f, 0.20361924f, 0.99917024f, 0.020154774f, 0.054515004f, 0.80709815f, 0.55225646f, 0.52884465f, - 0.22312081f, 0.29026228f, 0.35380626f, 0.012922287f, 0.52598435f, 0.58842945f, 0.4995767f, 0.66146517f, - 0.9744255f, 0.632942f, 0.3169638f, 0.29422665f, 0.18009722f, 0.15339059f, 0.41947508f, 0.4115672f, - 0.72243124f, 0.2862816f, 0.89860183f, 0.14915991f, 0.5014211f, 0.94945997f, 0.99719256f, 0.21036887f, - 0.5890645f, 0.55906135f, 0.26557416f, 0.32725257f, 0.635427f, 0.1523174f, 0.58249784f, 0.71636236f, - 0.30296493f, 0.9153206f, 0.46709478f, 0.72685635f, 0.9951532f, 0.34716582f, 0.7717041f, 0.3569854f, - 0.4269635f, 0.41526443f, 0.4968937f, 0.3111158f, 0.61719346f, 0.5188402f, 0.8169449f, 0.39879733f, - 0.5501401f, 0.31400484f, 0.08127314f, 0.7023336f, 0.56397897f, 0.29975814f, 0.33094752f, 0.63076067f, - 0.40959156f, 0.82673794f, 0.52832156f, 0.68886834f, 0.7178481f, 0.37731683f, 0.71633244f, 0.86896664f, - 0.5230092f, 0.59784645f, 0.5181678f, 0.8461837f, 0.28890234f, 0.23421508f, 0.7178768f, 0.06484294f, - 0.5080162f, 0.27005446f, 0.8300168f, 0.034480453f, 0.8031663f, 0.9946784f, 0.60117006f, 0.46668667f, - 0.9921749f, 0.28632385f, 0.45993322f, 0.28104752f, 0.43097937f, 0.60866946f, 0.5667807f, 0.40556252f, - 7.969141e-05f, 0.52560204f, 0.48518902f, 0.5752184f, 0.8831251f, 0.9860047f, 0.20335877f, 0.46882278f, - 0.2996632f, 0.03917718f, 0.13617045f, 0.96928054f, 0.79153055f, 0.76857555f, 0.7778716f, 0.102760494f, - 0.5525096f, 0.9653573f, 0.22095704f, 0.94479716f, 0.63141924f, 0.8517718f, 0.28580618f, 0.73050886f, - 0.05675614f, 0.46825224f, 0.6667756f, 0.6499472f, 0.91840404f, 0.99132854f, 0.9548785f, 0.8356961f, - 0.851531f, 0.43548512f, 0.111976564f, 0.31438643f, 0.44386774f, 0.22980672f, 0.75558543f, 0.6755136f, - 0.58067596f, 0.62078035f, 0.93922615f, 0.6821157f, 0.061530292f, 0.13705963f, 0.7203748f, 0.5681396f, - 0.7438458f, 0.0006400347f, 0.038565338f, 0.8066132f, 0.81982285f, 0.047644496f, 0.68979263f, 0.109577894f, - 0.8786539f, 0.6568952f, 0.99439347f, 0.0070040226f, 0.018661916f, 0.838051f, 0.94391155f, 0.80634f, - 0.8324149f, 0.078864336f, 0.8619068f, 0.027926445f, 0.61170083f, 0.17248261f, 0.30140227f, 0.5885344f, - 0.30341f, 0.42088854f, 0.02608782f, 0.02856338f, 0.69368154f, 0.28836077f, 0.19580519f, 0.30270886f, - 0.09121573f, 0.100299895f, 0.79918617f, 0.75412107f, 0.56660175f, 0.22687018f, 0.6663505f, 0.5224626f, - 0.1426636f, 0.6075949f, 0.95527196f, 0.008196831f, 0.0028039217f, 0.5640625f, 0.87651116f, 0.19575512f, - 0.61006856f, 0.85149264f, 0.6541582f, 0.6082054f, 0.998863f, 0.82573634f, 0.21878648f, 0.54321826f, - 0.7554362f, 0.94095474f, 0.002533555f, 0.77075267f, 0.35483408f, 0.010389388f, 0.610987f, 0.22779316f, - 0.5708561f, 0.17537653f, 0.12373549f, 0.4575745f, 0.33203715f, 0.79243237f, 0.54310906f, 0.8902793f, - 0.5937015f, 0.33921933f, 0.8386668f, 0.52732253f, 0.59384584f, 0.3391887f, 0.5017944f, 0.40386343f, - 0.45749134f, 0.110060334f, 0.49692506f, 0.084977865f, 0.3924346f, 0.7897731f, 0.15232486f, 0.16297412f, - 0.37791175f, 0.36293298f, 
0.5846437f, 0.5830078f, 0.75354826f, 0.15555972f, 0.4647144f, 0.7796456f, - 0.93248576f, 0.46352726f, 0.2106899f, 0.6437313f, 0.78473866f, 0.18762505f, 0.20985329f, 0.7209991f, - 0.464967f, 0.02775067f, 0.21170747f, 0.7027664f, 0.33041215f, 0.8451145f, 0.89526993f, 0.57273495f, - 0.46046263f, 0.34128642f, 0.47471708f, 0.59101045f, 0.11807448f, 0.38050216f, 0.08409953f, 0.80687743f, - 0.18158185f, 0.9567719f, 0.3711096f, 0.21356237f, 0.74022657f, 0.57453954f, 0.846228f, 0.70873487f, - 0.018330276f, 0.8162452f, 0.40584308f, 0.27901447f, 0.81752694f, 0.86466515f, 0.060534656f, 0.45478833f, - 0.9106033f, 0.6936434f, 0.92123467f, 0.32865065f, 0.22417879f, 0.9299548f, 0.70841146f, 0.97999126f, - 0.2911517f, 0.17896658f, 0.44139355f, 0.029210031f, 0.6959876f, 0.8687942f, 0.62002844f, 0.45059657f, - 0.74790317f, 0.18262434f, 0.98912156f, 0.0028281808f, 0.021027386f, 0.38184917f, 0.90842223f, 0.5500629f, - 0.69202286f, 0.13349658f, 0.6823429f, 0.44412827f, 0.7004118f, 0.8531213f, 0.7173401f, 0.4574679f, - 0.46920043f, 0.18640989f, 0.31914896f, 0.82491904f, 0.29950172f, 0.8105199f, 0.30173403f, 0.38355058f, - 0.5106411f, 0.04116726f, 0.49500751f, 0.44960213f, 0.45508182f, 0.4000479f, 0.89418864f, 0.8689936f, - 0.16112137f, 0.7322634f, 0.10780871f, 0.07433933f, 0.652841f, 0.50734824f, 0.26674682f, 0.017748117f, - 0.30643195f, 0.66699976f, 0.03719926f, 0.014267266f, 0.56343627f, 0.13979793f, 0.061959863f, 0.3073569f, - 0.41949958f, 0.045647383f, 0.16613615f, 0.5327839f, 0.028514147f, 0.4297228f, 0.17714864f, 0.15338135f, - 0.6965155f, 0.11515516f, 0.1210829f, 0.78514075f, 0.59348315f, 0.9553564f, 0.36635226f, 0.25849247f, - 0.45372677f, 0.5025297f, 0.88132215f, 0.0019600391f, 0.46439964f, 0.7211761f, 0.22465849f, 0.2459296f, - 0.7416339f, 0.020907402f, 0.6184779f, 0.112906754f, 0.7485309f, 0.072479784f, 0.8074024f, 0.026683688f, - 0.07971662f, 0.50736845f, 0.8939942f, 0.0718022f, 0.27697015f, 0.9391413f, 0.4161513f, 0.7071423f, - 0.019000888f, 0.34275955f, 0.24608392f, 0.9215306f, 0.70751995f, 0.13516217f, 0.5806135f, 0.49425328f, - 0.29456508f, 0.21446168f, 0.3340807f, 0.89411324f, 0.14157385f, 0.14382833f, 0.34574044f, 0.50869817f, - 0.63610595f, 0.51500404f, 0.37963718f, 0.19682491f, 0.41028368f, 0.29872334f, 0.9039644f, 0.013295233f, - 0.1810705f, 0.093204916f, 0.4086216f, 0.8896367f, 0.9382696f, 0.06472236f, 0.47833657f, 0.7934831f, - 0.7203987f, 0.9095519f, 0.4861309f, 0.16405362f, 0.83076525f, 0.3285427f, 0.7588931f, 0.37678176f, - 0.71254706f, 0.949713f, 0.96492773f, 0.044967473f, 0.16925985f, 0.2932666f, 0.18114948f, 0.97975004f, - 0.4558406f, 0.16832972f, 0.27750528f, 0.2238177f, 0.7039947f, 0.06387442f, 0.033798456f, 0.007119417f}; + 0.18302453f, 0.06460017f, 0.9055086f, 0.80260223f, 0.43536508f, 0.73533344f, 0.87621754f, + 0.57699674f, 0.11638266f, 0.9980005f, 0.22312081f, 0.9744255f, 0.72243124f, 0.5890645f, + 0.30296493f, 0.4269635f, 0.44593316f, 0.04124081f, 0.12955844f, 0.24076664f, 0.13430476f, + 0.9296606f, 0.9212578f, 0.7996927f, 0.3538614f, 0.20361924f, 0.29026228f, 0.632942f, + 0.2862816f, 0.55906135f, 0.9153206f, 0.41526443f, 0.5643144f, 0.98830533f, 0.4198916f, + 0.40261286f, 0.64640516f, 0.83163047f, 0.54665035f, 0.04920423f, 0.66399014f, 0.99917024f, + 0.35380626f, 0.3169638f, 0.89860183f, 0.26557416f, 0.46709478f, 0.4968937f, 0.9259722f, + 0.37530023f, 0.20413119f, 0.89688545f, 0.14449131f, 0.23771948f, 0.6135615f, 0.95198375f, + 0.06195748f, 0.020154774f, 0.012922287f, 0.29422665f, 0.14915991f, 0.32725257f, 0.72685635f, + 0.3111158f, 0.26143986f, 0.5249744f, 0.21432412f, 
0.38691485f, 0.10324633f, 0.44519007f, + 0.28353918f, 0.67986554f, 0.7740991f, 0.054515004f, 0.52598435f, 0.18009722f, 0.5014211f, + 0.635427f, 0.9951532f, 0.61719346f, 0.82031804f, 0.63555616f, 0.6186035f, 0.5455279f, + 0.5304596f, 0.34265757f, 0.8774212f, 0.14969361f, 0.7602738f, 0.80709815f, 0.58842945f, + 0.15339059f, 0.94945997f, 0.1523174f, 0.34716582f, 0.5188402f, 0.4364831f, 0.8398661f, + 0.969324f, 0.15048373f, 0.8964218f, 0.09793854f, 0.29194576f, 0.39229625f, 0.81010276f, + 0.55225646f, 0.4995767f, 0.41947508f, 0.99719256f, 0.58249784f, 0.7717041f, 0.8169449f, + 0.2625361f, 0.92673707f, 0.099448025f, 0.92562044f, 0.358508f, 0.5002066f, 0.1526736f, + 0.93378997f, 0.18122643f, 0.52884465f, 0.66146517f, 0.4115672f, 0.21036887f, 0.71636236f, + 0.3569854f, 0.39879733f, 0.5501401f, 0.40959156f, 0.5230092f, 0.5080162f, 0.9921749f, + 7.969141e-05f, 0.2996632f, 0.5525096f, 0.05675614f, 0.851531f, 0.58067596f, 0.7438458f, + 0.8786539f, 0.8324149f, 0.30341f, 0.09121573f, 0.31400484f, 0.82673794f, 0.59784645f, + 0.27005446f, 0.28632385f, 0.52560204f, 0.03917718f, 0.9653573f, 0.46825224f, 0.43548512f, + 0.62078035f, 0.0006400347f, 0.6568952f, 0.078864336f, 0.42088854f, 0.100299895f, 0.08127314f, + 0.52832156f, 0.5181678f, 0.8300168f, 0.45993322f, 0.48518902f, 0.13617045f, 0.22095704f, + 0.6667756f, 0.111976564f, 0.93922615f, 0.038565338f, 0.99439347f, 0.8619068f, 0.02608782f, + 0.79918617f, 0.7023336f, 0.68886834f, 0.8461837f, 0.034480453f, 0.28104752f, 0.5752184f, + 0.96928054f, 0.94479716f, 0.6499472f, 0.31438643f, 0.6821157f, 0.8066132f, 0.0070040226f, + 0.027926445f, 0.02856338f, 0.75412107f, 0.56397897f, 0.7178481f, 0.28890234f, 0.8031663f, + 0.43097937f, 0.8831251f, 0.79153055f, 0.63141924f, 0.91840404f, 0.44386774f, 0.061530292f, + 0.81982285f, 0.018661916f, 0.61170083f, 0.69368154f, 0.56660175f, 0.29975814f, 0.37731683f, + 0.23421508f, 0.9946784f, 0.60866946f, 0.9860047f, 0.76857555f, 0.8517718f, 0.99132854f, + 0.22980672f, 0.13705963f, 0.047644496f, 0.838051f, 0.17248261f, 0.28836077f, 0.22687018f, + 0.33094752f, 0.71633244f, 0.7178768f, 0.60117006f, 0.5667807f, 0.20335877f, 0.7778716f, + 0.28580618f, 0.9548785f, 0.75558543f, 0.7203748f, 0.68979263f, 0.94391155f, 0.30140227f, + 0.19580519f, 0.6663505f, 0.63076067f, 0.86896664f, 0.06484294f, 0.46668667f, 0.40556252f, + 0.46882278f, 0.102760494f, 0.73050886f, 0.8356961f, 0.6755136f, 0.5681396f, 0.109577894f, + 0.80634f, 0.5885344f, 0.30270886f, 0.5224626f, 0.1426636f, 0.61006856f, 0.7554362f, + 0.5708561f, 0.5937015f, 0.45749134f, 0.37791175f, 0.93248576f, 0.464967f, 0.46046263f, + 0.18158185f, 0.018330276f, 0.9106033f, 0.2911517f, 0.74790317f, 0.69202286f, 0.6075949f, + 0.85149264f, 0.94095474f, 0.17537653f, 0.33921933f, 0.110060334f, 0.36293298f, 0.46352726f, + 0.02775067f, 0.34128642f, 0.9567719f, 0.8162452f, 0.6936434f, 0.17896658f, 0.18262434f, + 0.13349658f, 0.95527196f, 0.6541582f, 0.002533555f, 0.12373549f, 0.8386668f, 0.49692506f, + 0.5846437f, 0.2106899f, 0.21170747f, 0.47471708f, 0.3711096f, 0.40584308f, 0.92123467f, + 0.44139355f, 0.98912156f, 0.6823429f, 0.008196831f, 0.6082054f, 0.77075267f, 0.4575745f, + 0.52732253f, 0.084977865f, 0.5830078f, 0.6437313f, 0.7027664f, 0.59101045f, 0.21356237f, + 0.27901447f, 0.32865065f, 0.029210031f, 0.0028281808f, 0.44412827f, 0.0028039217f, 0.998863f, + 0.35483408f, 0.33203715f, 0.59384584f, 0.3924346f, 0.75354826f, 0.78473866f, 0.33041215f, + 0.11807448f, 0.74022657f, 0.81752694f, 0.22417879f, 0.6959876f, 0.021027386f, 0.7004118f, + 0.5640625f, 0.82573634f, 0.010389388f, 0.79243237f, 
0.3391887f, 0.7897731f, 0.15555972f, + 0.18762505f, 0.8451145f, 0.38050216f, 0.57453954f, 0.86466515f, 0.9299548f, 0.8687942f, + 0.38184917f, 0.8531213f, 0.87651116f, 0.21878648f, 0.610987f, 0.54310906f, 0.5017944f, + 0.15232486f, 0.4647144f, 0.20985329f, 0.89526993f, 0.08409953f, 0.846228f, 0.060534656f, + 0.70841146f, 0.62002844f, 0.90842223f, 0.7173401f, 0.19575512f, 0.54321826f, 0.22779316f, + 0.8902793f, 0.40386343f, 0.16297412f, 0.7796456f, 0.7209991f, 0.57273495f, 0.80687743f, + 0.70873487f, 0.45478833f, 0.97999126f, 0.45059657f, 0.5500629f, 0.4574679f, 0.46920043f, + 0.5106411f, 0.16112137f, 0.30643195f, 0.41949958f, 0.6965155f, 0.45372677f, 0.7416339f, + 0.07971662f, 0.019000888f, 0.29456508f, 0.63610595f, 0.1810705f, 0.7203987f, 0.71254706f, + 0.4558406f, 0.18640989f, 0.04116726f, 0.7322634f, 0.66699976f, 0.045647383f, 0.11515516f, + 0.5025297f, 0.020907402f, 0.50736845f, 0.34275955f, 0.21446168f, 0.51500404f, 0.093204916f, + 0.9095519f, 0.949713f, 0.16832972f, 0.31914896f, 0.49500751f, 0.10780871f, 0.03719926f, + 0.16613615f, 0.1210829f, 0.88132215f, 0.6184779f, 0.8939942f, 0.24608392f, 0.3340807f, + 0.37963718f, 0.4086216f, 0.4861309f, 0.96492773f, 0.27750528f, 0.82491904f, 0.44960213f, + 0.07433933f, 0.014267266f, 0.5327839f, 0.78514075f, 0.0019600391f, 0.112906754f, 0.0718022f, + 0.9215306f, 0.89411324f, 0.19682491f, 0.8896367f, 0.16405362f, 0.044967473f, 0.2238177f, + 0.29950172f, 0.45508182f, 0.652841f, 0.56343627f, 0.028514147f, 0.59348315f, 0.46439964f, + 0.7485309f, 0.27697015f, 0.70751995f, 0.14157385f, 0.41028368f, 0.9382696f, 0.83076525f, + 0.16925985f, 0.7039947f, 0.8105199f, 0.4000479f, 0.50734824f, 0.13979793f, 0.4297228f, + 0.9553564f, 0.7211761f, 0.072479784f, 0.9391413f, 0.13516217f, 0.14382833f, 0.29872334f, + 0.06472236f, 0.3285427f, 0.2932666f, 0.06387442f, 0.30173403f, 0.89418864f, 0.26674682f, + 0.061959863f, 0.17714864f, 0.36635226f, 0.22465849f, 0.8074024f, 0.4161513f, 0.5806135f, + 0.34574044f, 0.9039644f, 0.47833657f, 0.7588931f, 0.18114948f, 0.033798456f, 0.38355058f, + 0.8689936f, 0.017748117f, 0.3073569f, 0.15338135f, 0.25849247f, 0.2459296f, 0.026683688f, + 0.7071423f, 0.49425328f, 0.50869817f, 0.013295233f, 0.7934831f, 0.37678176f, 0.97975004f, + 0.007119417f}; const std::vector fc1_experts_bias = { 0.71526206f, 0.7472273f, 0.18946046f, 0.6239893f, 0.86909235f, 0.5726507f, 0.3942092f, 0.5369412f, 0.44638616f, 0.7517496f, 0.16049433f, 0.75355124f, 0.7818118f, 0.19706267f, 0.9082818f, 0.9910924f, @@ -256,19 +302,8 @@ TEST(MoETest, MoETest_Gelu) { 0.565234f, 0.17098689f, 0.10810414f, 0.43916586f, 0.3535297f, 0.45673048f, 0.3853893f, 0.18613164f, 1.3354061f, 0.5049282f, 0.72775036f, 0.90331376f, 1.2945517f, 0.9123066f, 1.1995136f, 0.7708638f}; - RunMoETest(input, - router_probs, - fc1_experts_weights, - fc2_experts_weights, - {}, - fc1_experts_bias, - fc2_experts_bias, - output, - num_rows, - num_experts, - hidden_size, - inter_size, - "gelu"); + RunMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, {}, fc1_experts_bias, fc2_experts_bias, + output, num_rows, num_experts, hidden_size, inter_size, "gelu"); } TEST(MoETest, MoETest_Relu) { @@ -286,135 +321,145 @@ TEST(MoETest, MoETest_Relu) { -0.08146476f, -0.40439552f, 1.0100367f, -0.7724162f, -0.08113786f, -0.36328858f, 0.3688482f, -0.013465762f, -0.32420647f, -0.3815508f, 0.79585606f, 0.14430691f, -0.21869831f, 0.11483674f, -0.11992836f, 0.35216537f}; const std::vector fc1_experts_weights = { - 0.81960344f, 0.9296998f, 0.45050132f, 0.38805157f, 0.50729614f, 0.47014588f, 0.62020564f, 
0.6401168f, - 0.045871615f, 0.31548113f, 0.92106473f, 0.6947775f, 0.4751312f, 0.19854712f, 0.19409746f, 0.052116573f, - 0.3370188f, 0.6688521f, 0.8188108f, 0.73084867f, 0.058027983f, 0.19931877f, 0.42109168f, 0.98367476f, - 0.57232875f, 0.37051463f, 0.7068576f, 0.30955923f, 0.17637217f, 0.8649436f, 0.2726491f, 0.39976662f, - 0.0025978684f, 0.8346353f, 0.8788173f, 0.6822241f, 0.1513629f, 0.0065300465f, 0.093910515f, 0.8728501f, - 0.7400529f, 0.9207522f, 0.76193494f, 0.6265461f, 0.49510366f, 0.11974698f, 0.07161391f, 0.032325685f, - 0.704681f, 0.254516f, 0.3993737f, 0.21224737f, 0.40888822f, 0.14808255f, 0.17329216f, 0.6658554f, - 0.3514018f, 0.8086716f, 0.33959562f, 0.13321638f, 0.41178054f, 0.2576263f, 0.3470292f, 0.024002194f, - 0.77974546f, 0.15189773f, 0.75130886f, 0.7268921f, 0.85721636f, 0.11647397f, 0.8595984f, 0.2636242f, - 0.6855346f, 0.96955734f, 0.42948407f, 0.49613327f, 0.38488472f, 0.08250773f, 0.73995143f, 0.003641069f, - 0.81039995f, 0.87411255f, 0.9728532f, 0.38206023f, 0.08917904f, 0.61241513f, 0.77621365f, 0.0023456216f, - 0.38650817f, 0.20027226f, 0.45626813f, 0.25389326f, 0.2956162f, 0.34127057f, 0.024847984f, 0.91025376f, - 0.9191656f, 0.42156547f, 0.44305897f, 0.29594004f, 0.04846859f, 0.013427794f, 0.6858292f, 0.22547692f, - 0.17856151f, 0.4609884f, 0.33349442f, 0.3382396f, 0.5160656f, 0.3939438f, 0.3278438f, 0.26059705f, - 0.0930863f, 0.9192536f, 0.29990643f, 0.63248974f, 0.32651705f, 0.54063064f, 0.9661502f, 0.73036134f, - 0.06670016f, 0.6984514f, 0.9746214f, 0.63154167f, 0.83521235f, 0.99294376f, 0.4233855f, 0.6037772f, - 0.15248245f, 0.39696145f, 0.8702919f, 0.7563229f, 0.18360549f, 0.099057496f, 0.15831816f, 0.00656116f, - 0.114180505f, 0.3763513f, 0.8374386f, 0.5836911f, 0.11969727f, 0.09888804f, 0.74873763f, 0.12807935f, - 0.43843627f, 0.739853f, 0.26859397f, 0.44548005f, 0.45647776f, 0.38170832f, 0.24648392f, 0.054280818f, - 0.0958215f, 0.23226917f, 0.98291886f, 0.25849265f, 0.16423601f, 0.6211971f, 0.63780516f, 0.77395487f, - 0.8800602f, 0.7784371f, 0.004249513f, 0.5443443f, 0.80287653f, 0.45378727f, 0.20536041f, 0.9766699f, - 0.31298608f, 0.21532774f, 0.04922247f, 0.52233416f, 0.72156656f, 0.6106814f, 0.59887487f, 0.12080628f, - 0.03305638f, 0.5088047f, 0.95591706f, 0.7884607f, 0.20888287f, 0.43509573f, 0.13140821f, 0.2587883f, - 0.5905492f, 0.77226925f, 0.91418463f, 0.04094696f, 0.8343076f, 0.14735395f, 0.6872336f, 0.92312264f, - 0.5070212f, 0.9549045f, 0.07397425f, 0.3090204f, 0.79162645f, 0.39106607f, 0.39764988f, 0.29160416f, - 0.84465307f, 0.7452516f, 0.66022503f, 0.21901816f, 0.09412521f, 0.5540803f, 0.6481394f, 0.26914406f, - 0.36010116f, 0.83768386f, 0.53982985f, 0.52255917f, 0.37694973f, 0.04720515f, 0.029871285f, 0.26099247f, - 0.2458393f, 0.6557768f, 0.35444462f, 0.30438894f, 0.9767149f, 0.67416143f, 0.85645115f, 0.25794363f, - 0.2957666f, 0.68377024f, 0.16686243f, 0.17314798f, 0.47585016f, 0.31711966f, 0.125171f, 0.7965795f, - 0.90208143f, 0.58111167f, 0.41294336f, 0.036863506f, 0.31788063f, 0.6272928f, 0.73576546f, 0.43679124f, - 0.30232358f, 0.77861303f, 0.10180014f, 0.816009f, 0.30602258f, 0.5076527f, 0.40119207f, 0.5606195f, - 0.3489008f, 0.8635635f, 0.48700142f, 0.89029974f, 0.98074025f, 0.25640452f, 0.13524544f, 0.901151f, - 0.89180696f, 0.11822635f, 0.46134835f, 0.006936848f, 0.09070045f, 0.59657127f, 0.6330173f, 0.6059905f, - 0.36391765f, 0.96128887f, 0.571489f, 0.2049576f, 0.4716931f, 0.6200726f, 0.67509633f, 0.14645958f, - 0.6873948f, 0.24455917f, 0.08452982f, 0.22689629f, 0.9822047f, 0.9274289f, 0.9477422f, 0.7935056f, - 0.87772477f, 
0.43307513f, 0.22488606f, 0.7498283f, 0.24090862f, 0.16256708f, 0.34033298f, 0.6049296f, - 0.7573983f, 0.3057955f, 0.20571685f, 0.56744653f, 0.2052834f, 0.17446929f, 0.76062596f, 0.4160077f, - 0.9568925f, 0.9863913f, 0.64955276f, 0.67207885f, 0.61514187f, 0.50783044f, 0.46363378f, 0.50687206f, - 0.6867124f, 0.9648854f, 0.37042046f, 0.2886421f, 0.37891757f, 0.25843787f, 0.58501935f, 0.8732242f, - 0.8909887f, 0.72956276f, 0.13203424f, 0.23164761f, 0.3901443f, 0.40783793f, 0.54112387f, 0.041014254f, - 0.65562236f, 0.11856395f, 0.18362767f, 0.08430874f, 0.9356598f, 0.026530087f, 0.8771834f, 0.48319155f, - 0.4418506f, 0.81273925f, 0.4537862f, 0.81357706f, 0.8615075f, 0.06589496f, 0.692392f, 0.5943895f, - 0.60750586f, 0.5729957f, 0.6367655f, 0.2594666f, 0.43602943f, 0.97506f, 0.83592474f, 0.48121578f, - 0.029734552f, 0.5219139f, 0.15951324f, 0.90659577f, 0.19645631f, 0.4638992f, 0.38902867f, 0.5889769f, - 0.9705138f, 0.5475096f, 0.789582f, 0.8881108f, 0.9036556f, 0.32732427f, 0.38817167f, 0.7409689f, - 0.36356616f, 0.734132f, 0.39076614f, 0.16087383f, 0.70352167f, 0.576659f, 0.7229242f, 0.996743f, - 0.84136647f, 0.97399056f, 0.5267614f, 0.06989372f, 0.14923638f, 0.18941313f, 0.059375823f, 0.24937624f, - 0.039716125f, 0.038692355f, 0.20122272f, 0.0070830584f, 0.19309378f, 0.69065434f, 0.9170264f, 0.3512686f, - 0.3545606f, 0.76697665f, 0.25331455f, 0.26358372f, 0.80806476f, 0.064349174f, 0.5611374f, 0.941691f, - 0.58574325f, 0.6359719f, 0.20880443f, 0.49310172f, 0.5274922f, 0.62271714f, 0.694273f, 0.9344639f, - 0.11835027f, 0.51498765f, 0.25018185f, 0.10446805f, 0.45996118f, 0.059881568f, 0.8489496f, 0.5579074f, - 0.23052096f, 0.76128954f, 0.02678603f, 0.3066004f, 0.40259063f, 0.07512486f, 0.18205583f, 0.4183907f, - 0.8793823f, 0.9828271f, 0.8181312f, 0.20143801f, 0.17288941f, 0.9363466f, 0.6768587f, 0.51328385f, - 0.56766605f, 0.098151624f, 0.33305728f, 0.98130906f, 0.3766839f, 0.47491795f, 0.08483446f, 0.22029644f, - 0.4897902f, 0.18942028f, 0.4379952f, 0.7034796f, 0.0109113455f, 0.64850605f, 0.16939592f, 0.25597447f, - 0.69195485f, 0.8975601f, 0.36334568f, 0.29471546f, 0.04788208f, 0.24217117f, 0.062181532f, 0.38556474f, - 0.6020277f, 0.03156215f, 0.93655676f, 0.81369543f, 0.010527074f, 0.2611835f, 0.6630776f, 0.3972702f, - 0.44551176f, 0.27424216f, 0.9016098f, 0.22050089f, 0.9146384f, 0.53226113f, 0.6005109f, 0.8900659f, - 0.4176172f, 0.21532834f, 0.4191329f, 0.9055267f, 0.12900633f, 0.6134902f, 0.008604288f, 0.76215106f, - 0.68473387f, 0.5211961f, 0.71459657f, 0.50056237f, 0.7766764f, 0.10418975f, 0.42657375f, 0.7218073f, - 0.9979084f, 0.7546957f, 0.1364128f, 0.8845484f, 0.38850087f, 0.39324278f, 0.04554516f, 0.42129284f, - 0.8536634f, 0.5697224f, 0.20877302f, 0.65390605f, 0.3396778f, 0.956497f, 0.066022694f, 0.34206223f, - 0.017213225f, 0.3030849f, 0.6576238f, 0.9813073f, 0.58397317f, 0.99017924f, 0.59782606f, 0.788768f, - 0.9008311f, 0.91796166f, 0.22013813f, 0.959695f, 0.80288273f, 0.2662105f, 0.26139832f, 0.080626905f}; + 0.81960344f, 0.3370188f, 0.0025978684f, 0.704681f, 0.77974546f, 0.81039995f, 0.9191656f, + 0.0930863f, 0.9296998f, 0.6688521f, 0.8346353f, 0.254516f, 0.15189773f, 0.87411255f, + 0.42156547f, 0.9192536f, 0.45050132f, 0.8188108f, 0.8788173f, 0.3993737f, 0.75130886f, + 0.9728532f, 0.44305897f, 0.29990643f, 0.38805157f, 0.73084867f, 0.6822241f, 0.21224737f, + 0.7268921f, 0.38206023f, 0.29594004f, 0.63248974f, 0.50729614f, 0.058027983f, 0.1513629f, + 0.40888822f, 0.85721636f, 0.08917904f, 0.04846859f, 0.32651705f, 0.47014588f, 0.19931877f, + 0.0065300465f, 0.14808255f, 
0.11647397f, 0.61241513f, 0.013427794f, 0.54063064f, 0.62020564f, + 0.42109168f, 0.093910515f, 0.17329216f, 0.8595984f, 0.77621365f, 0.6858292f, 0.9661502f, + 0.6401168f, 0.98367476f, 0.8728501f, 0.6658554f, 0.2636242f, 0.0023456216f, 0.22547692f, + 0.73036134f, 0.045871615f, 0.57232875f, 0.7400529f, 0.3514018f, 0.6855346f, 0.38650817f, + 0.17856151f, 0.06670016f, 0.31548113f, 0.37051463f, 0.9207522f, 0.8086716f, 0.96955734f, + 0.20027226f, 0.4609884f, 0.6984514f, 0.92106473f, 0.7068576f, 0.76193494f, 0.33959562f, + 0.42948407f, 0.45626813f, 0.33349442f, 0.9746214f, 0.6947775f, 0.30955923f, 0.6265461f, + 0.13321638f, 0.49613327f, 0.25389326f, 0.3382396f, 0.63154167f, 0.4751312f, 0.17637217f, + 0.49510366f, 0.41178054f, 0.38488472f, 0.2956162f, 0.5160656f, 0.83521235f, 0.19854712f, + 0.8649436f, 0.11974698f, 0.2576263f, 0.08250773f, 0.34127057f, 0.3939438f, 0.99294376f, + 0.19409746f, 0.2726491f, 0.07161391f, 0.3470292f, 0.73995143f, 0.024847984f, 0.3278438f, + 0.4233855f, 0.052116573f, 0.39976662f, 0.032325685f, 0.024002194f, 0.003641069f, 0.91025376f, + 0.26059705f, 0.6037772f, 0.15248245f, 0.43843627f, 0.8800602f, 0.03305638f, 0.5070212f, + 0.36010116f, 0.2957666f, 0.30232358f, 0.39696145f, 0.739853f, 0.7784371f, 0.5088047f, + 0.9549045f, 0.83768386f, 0.68377024f, 0.77861303f, 0.8702919f, 0.26859397f, 0.004249513f, + 0.95591706f, 0.07397425f, 0.53982985f, 0.16686243f, 0.10180014f, 0.7563229f, 0.44548005f, + 0.5443443f, 0.7884607f, 0.3090204f, 0.52255917f, 0.17314798f, 0.816009f, 0.18360549f, + 0.45647776f, 0.80287653f, 0.20888287f, 0.79162645f, 0.37694973f, 0.47585016f, 0.30602258f, + 0.099057496f, 0.38170832f, 0.45378727f, 0.43509573f, 0.39106607f, 0.04720515f, 0.31711966f, + 0.5076527f, 0.15831816f, 0.24648392f, 0.20536041f, 0.13140821f, 0.39764988f, 0.029871285f, + 0.125171f, 0.40119207f, 0.00656116f, 0.054280818f, 0.9766699f, 0.2587883f, 0.29160416f, + 0.26099247f, 0.7965795f, 0.5606195f, 0.114180505f, 0.0958215f, 0.31298608f, 0.5905492f, + 0.84465307f, 0.2458393f, 0.90208143f, 0.3489008f, 0.3763513f, 0.23226917f, 0.21532774f, + 0.77226925f, 0.7452516f, 0.6557768f, 0.58111167f, 0.8635635f, 0.8374386f, 0.98291886f, + 0.04922247f, 0.91418463f, 0.66022503f, 0.35444462f, 0.41294336f, 0.48700142f, 0.5836911f, + 0.25849265f, 0.52233416f, 0.04094696f, 0.21901816f, 0.30438894f, 0.036863506f, 0.89029974f, + 0.11969727f, 0.16423601f, 0.72156656f, 0.8343076f, 0.09412521f, 0.9767149f, 0.31788063f, + 0.98074025f, 0.09888804f, 0.6211971f, 0.6106814f, 0.14735395f, 0.5540803f, 0.67416143f, + 0.6272928f, 0.25640452f, 0.74873763f, 0.63780516f, 0.59887487f, 0.6872336f, 0.6481394f, + 0.85645115f, 0.73576546f, 0.13524544f, 0.12807935f, 0.77395487f, 0.12080628f, 0.92312264f, + 0.26914406f, 0.25794363f, 0.43679124f, 0.901151f, 0.89180696f, 0.6873948f, 0.7573983f, + 0.6867124f, 0.65562236f, 0.60750586f, 0.9705138f, 0.84136647f, 0.11822635f, 0.24455917f, + 0.3057955f, 0.9648854f, 0.11856395f, 0.5729957f, 0.5475096f, 0.97399056f, 0.46134835f, + 0.08452982f, 0.20571685f, 0.37042046f, 0.18362767f, 0.6367655f, 0.789582f, 0.5267614f, + 0.006936848f, 0.22689629f, 0.56744653f, 0.2886421f, 0.08430874f, 0.2594666f, 0.8881108f, + 0.06989372f, 0.09070045f, 0.9822047f, 0.2052834f, 0.37891757f, 0.9356598f, 0.43602943f, + 0.9036556f, 0.14923638f, 0.59657127f, 0.9274289f, 0.17446929f, 0.25843787f, 0.026530087f, + 0.97506f, 0.32732427f, 0.18941313f, 0.6330173f, 0.9477422f, 0.76062596f, 0.58501935f, + 0.8771834f, 0.83592474f, 0.38817167f, 0.059375823f, 0.6059905f, 0.7935056f, 0.4160077f, + 0.8732242f, 0.48319155f, 
0.48121578f, 0.7409689f, 0.24937624f, 0.36391765f, 0.87772477f, + 0.9568925f, 0.8909887f, 0.4418506f, 0.029734552f, 0.36356616f, 0.039716125f, 0.96128887f, + 0.43307513f, 0.9863913f, 0.72956276f, 0.81273925f, 0.5219139f, 0.734132f, 0.038692355f, + 0.571489f, 0.22488606f, 0.64955276f, 0.13203424f, 0.4537862f, 0.15951324f, 0.39076614f, + 0.20122272f, 0.2049576f, 0.7498283f, 0.67207885f, 0.23164761f, 0.81357706f, 0.90659577f, + 0.16087383f, 0.0070830584f, 0.4716931f, 0.24090862f, 0.61514187f, 0.3901443f, 0.8615075f, + 0.19645631f, 0.70352167f, 0.19309378f, 0.6200726f, 0.16256708f, 0.50783044f, 0.40783793f, + 0.06589496f, 0.4638992f, 0.576659f, 0.69065434f, 0.67509633f, 0.34033298f, 0.46363378f, + 0.54112387f, 0.692392f, 0.38902867f, 0.7229242f, 0.9170264f, 0.14645958f, 0.6049296f, + 0.50687206f, 0.041014254f, 0.5943895f, 0.5889769f, 0.996743f, 0.3512686f, 0.3545606f, + 0.11835027f, 0.8793823f, 0.4897902f, 0.6020277f, 0.4176172f, 0.9979084f, 0.017213225f, + 0.76697665f, 0.51498765f, 0.9828271f, 0.18942028f, 0.03156215f, 0.21532834f, 0.7546957f, + 0.3030849f, 0.25331455f, 0.25018185f, 0.8181312f, 0.4379952f, 0.93655676f, 0.4191329f, + 0.1364128f, 0.6576238f, 0.26358372f, 0.10446805f, 0.20143801f, 0.7034796f, 0.81369543f, + 0.9055267f, 0.8845484f, 0.9813073f, 0.80806476f, 0.45996118f, 0.17288941f, 0.0109113455f, + 0.010527074f, 0.12900633f, 0.38850087f, 0.58397317f, 0.064349174f, 0.059881568f, 0.9363466f, + 0.64850605f, 0.2611835f, 0.6134902f, 0.39324278f, 0.99017924f, 0.5611374f, 0.8489496f, + 0.6768587f, 0.16939592f, 0.6630776f, 0.008604288f, 0.04554516f, 0.59782606f, 0.941691f, + 0.5579074f, 0.51328385f, 0.25597447f, 0.3972702f, 0.76215106f, 0.42129284f, 0.788768f, + 0.58574325f, 0.23052096f, 0.56766605f, 0.69195485f, 0.44551176f, 0.68473387f, 0.8536634f, + 0.9008311f, 0.6359719f, 0.76128954f, 0.098151624f, 0.8975601f, 0.27424216f, 0.5211961f, + 0.5697224f, 0.91796166f, 0.20880443f, 0.02678603f, 0.33305728f, 0.36334568f, 0.9016098f, + 0.71459657f, 0.20877302f, 0.22013813f, 0.49310172f, 0.3066004f, 0.98130906f, 0.29471546f, + 0.22050089f, 0.50056237f, 0.65390605f, 0.959695f, 0.5274922f, 0.40259063f, 0.3766839f, + 0.04788208f, 0.9146384f, 0.7766764f, 0.3396778f, 0.80288273f, 0.62271714f, 0.07512486f, + 0.47491795f, 0.24217117f, 0.53226113f, 0.10418975f, 0.956497f, 0.2662105f, 0.694273f, + 0.18205583f, 0.08483446f, 0.062181532f, 0.6005109f, 0.42657375f, 0.066022694f, 0.26139832f, + 0.9344639f, 0.4183907f, 0.22029644f, 0.38556474f, 0.8900659f, 0.7218073f, 0.34206223f, + 0.080626905f}; const std::vector fc2_experts_weights = { - 0.6255686f, 0.09472537f, 0.71121234f, 0.65789884f, 0.065598905f, 0.63625044f, 0.45933473f, 0.7284089f, - 0.7868948f, 0.0029274821f, 0.95854944f, 0.919321f, 0.6989418f, 0.043019474f, 0.32138962f, 0.35509557f, - 0.37150103f, 0.78196156f, 0.6817853f, 0.89608955f, 0.31273842f, 0.6682699f, 0.6778976f, 0.08370459f, - 0.014990091f, 0.24055547f, 0.84227383f, 0.029270172f, 0.0647831f, 0.7801003f, 0.7697645f, 0.91119635f, - 0.12253064f, 0.13405013f, 0.75649333f, 0.9348151f, 0.7991694f, 0.57832605f, 0.66478735f, 0.97456336f, - 0.17739785f, 0.2729941f, 0.8497335f, 0.15788019f, 0.22429371f, 0.86499554f, 0.65776104f, 0.661535f, - 0.2880798f, 0.49309975f, 0.9576164f, 0.19988996f, 0.5039311f, 0.73779976f, 0.15482187f, 0.98558843f, - 0.25019473f, 0.379932f, 0.36471486f, 0.17417055f, 0.009367704f, 0.7819258f, 0.63283706f, 0.031699598f, - 0.1781866f, 0.994184f, 0.6911175f, 0.7006223f, 0.20085096f, 0.28080195f, 0.42452294f, 0.40856004f, - 0.15737581f, 0.5411925f, 0.549694f, 0.4366895f, 
0.5693159f, 0.3018247f, 0.63012594f, 0.6885702f, - 0.2366305f, 0.004210472f, 0.7617172f, 0.61926836f, 0.24570602f, 0.981851f, 0.273876f, 0.8378734f, - 0.75366426f, 0.080795944f, 0.82247066f, 0.040263534f, 0.22299266f, 0.41664255f, 0.16297674f, 0.98845494f, - 0.39971018f, 0.69859487f, 0.053544044f, 0.7878332f, 0.34460813f, 0.11966437f, 0.5731115f, 0.7422309f, - 0.93269855f, 0.19460368f, 0.25394785f, 0.59613144f, 0.6356306f, 0.6922361f, 0.7744376f, 0.38662314f, - 0.7777848f, 0.8686458f, 0.36938924f, 0.8557286f, 0.74428976f, 0.9410264f, 0.21586305f, 0.2530955f, - 0.35543054f, 0.52536315f, 0.8000995f, 0.21456867f, 0.750327f, 0.3208093f, 0.80205464f, 0.47626138f, - 0.061956525f, 0.22487706f, 0.13812399f, 0.74798125f, 0.1647259f, 0.45834088f, 0.6078779f, 0.22580266f, - 0.644235f, 0.011788309f, 0.14224577f, 0.0469383f, 0.34876132f, 0.3178513f, 0.5715967f, 0.40754277f, - 0.735041f, 0.9583977f, 0.67939556f, 0.30301625f, 0.031807184f, 0.68110096f, 0.25227106f, 0.75443816f, - 0.83424246f, 0.69286025f, 0.9691554f, 0.9748982f, 0.60586995f, 0.13568163f, 0.94672066f, 0.26275212f, - 0.2638232f, 0.9183893f, 0.88740516f, 0.65107566f, 0.5313419f, 0.07941705f, 0.44809794f, 0.9795632f, - 0.6273294f, 0.542809f, 0.3961745f, 0.32560885f, 0.79801136f, 0.53083426f, 0.8252871f, 0.4115007f, - 0.7184546f, 0.70638496f, 0.57973206f, 0.8141865f, 0.81332296f, 0.96346164f, 0.88438797f, 0.37215167f, - 0.0766899f, 0.5914087f, 0.49563587f, 0.3695873f, 0.41627264f, 0.5235164f, 0.86481494f, 0.6558706f, - 0.32245284f, 0.29438752f, 0.37618434f, 0.3067485f, 0.9496114f, 0.76482266f, 0.95148784f, 0.5015968f, - 0.60083544f, 0.67338234f, 0.026723444f, 0.5446483f, 0.466555f, 0.21967298f, 0.112026334f, 0.9426372f, - 0.906533f, 0.73173434f, 0.97712487f, 0.29709607f, 0.41363865f, 0.6893093f, 0.4173867f, 0.4018826f, - 0.086719275f, 0.63433063f, 0.1978364f, 0.5181831f, 0.9874878f, 0.34609234f, 0.34240413f, 0.8016564f, - 0.31617337f, 0.4570613f, 0.96686924f, 0.29501313f, 0.14229488f, 0.22017813f, 0.36137718f, 0.26275063f, - 0.24053413f, 0.70197225f, 0.58496886f, 0.33996922f, 0.11154431f, 0.34257007f, 0.28898042f, 0.33729053f, - 0.048938513f, 0.60771453f, 0.13263822f, 0.11060041f, 0.091483414f, 0.70869184f, 0.19898665f, 0.29362458f, - 0.8919203f, 0.7654821f, 0.7866956f, 0.02524674f, 0.1414501f, 0.3112445f, 0.9130488f, 0.5511502f, - 0.12605143f, 0.5031309f, 0.11166459f, 0.39045036f, 0.36251247f, 0.9328308f, 0.65486836f, 0.41281444f, - 0.5844644f, 0.35566723f, 0.6964502f, 0.6977819f, 0.63427305f, 0.30511153f, 0.92657536f, 0.42781502f, - 0.30534166f, 0.813157f, 0.90752834f, 0.9975799f, 0.64812917f, 0.32955307f, 0.753946f, 0.92897725f, - 0.009582937f, 0.43805653f, 0.15901726f, 0.5931799f, 0.7067924f, 0.39670604f, 0.45817143f, 0.7250554f, - 0.41596514f, 0.08011025f, 0.900068f, 0.24834275f, 0.44507074f, 0.5471632f, 0.46995157f, 0.029657006f, - 0.7294f, 0.27288425f, 0.2406702f, 0.6194577f, 0.23906898f, 0.26892018f, 0.33152503f, 0.3121612f, - 0.29118127f, 0.36515707f, 0.6299379f, 0.095391035f, 0.19735986f, 0.5072957f, 0.56953406f, 0.77614623f, - 0.14877802f, 0.65959847f, 0.7841949f, 0.7776301f, 0.03428924f, 0.3091979f, 0.07021719f, 0.18359429f, - 0.77849144f, 0.42534047f, 0.7123557f, 0.20649683f, 0.57597995f, 0.19757104f, 0.749946f, 0.2813105f, - 0.37462044f, 0.06618434f, 0.50165176f, 0.9747401f, 0.7426891f, 0.23322952f, 0.50672436f, 0.44517577f, - 0.09746289f, 0.89204556f, 0.50806034f, 0.6052985f, 0.2980855f, 0.26604044f, 0.5824448f, 0.68485546f, - 0.612149f, 0.25902748f, 0.9854489f, 0.4263978f, 0.19379246f, 0.26614368f, 0.9922104f, 0.5000241f, - 
0.4321279f, 0.2919191f, 0.3689273f, 0.078885734f, 0.10265827f, 0.79264474f, 0.9277247f, 0.9771502f, - 0.13902885f, 0.77043164f, 0.19051671f, 0.7982801f, 0.86077714f, 0.8869355f, 0.86002564f, 0.81278664f, - 0.5097318f, 0.7297412f, 0.32111454f, 0.7177174f, 0.33929902f, 0.49160433f, 0.064810574f, 0.3692627f, - 0.23706353f, 0.3313396f, 0.18070674f, 0.05027789f, 0.53255826f, 0.8244896f, 0.9553747f, 0.7917771f, - 0.24083132f, 0.005495131f, 0.6896569f, 0.78015697f, 0.07074398f, 0.67929304f, 0.9227386f, 0.5302883f, - 0.19877058f, 0.90993816f, 0.71350795f, 0.8311006f, 0.16185725f, 0.79097277f, 0.15846318f, 0.99474716f, - 0.28815013f, 0.80128354f, 0.6001208f, 0.63250524f, 0.4233225f, 0.7053677f, 0.29161406f, 0.028710365f, - 0.30789846f, 0.8917693f, 0.36836517f, 0.6571592f, 0.3151368f, 0.8750746f, 0.7992451f, 0.6765068f, - 0.24441916f, 0.091435075f, 0.5188247f, 0.20667112f, 0.9110969f, 0.019512117f, 0.72343415f, 0.998457f, - 0.7504142f, 0.6704894f, 0.01892668f, 0.9809466f, 0.41447622f, 0.032795787f, 0.9935814f, 0.29653466f, - 0.4646262f, 0.95763975f, 0.15339965f, 0.14625502f, 0.58130866f, 0.43307304f, 0.6151709f, 0.08064735f, - 0.5149533f, 0.27762014f, 0.25419557f, 0.04218155f, 0.7651092f, 0.59631824f, 0.077278376f, 0.89677596f, - 0.6508104f, 0.5927816f, 0.2064318f, 0.57540226f, 0.9817701f, 0.84294224f, 0.11056489f, 0.9564106f, - 0.5387549f, 0.74048257f, 0.88833815f, 0.9262546f, 0.11023259f, 0.93783194f, 0.16041255f, 0.53748304f, - 0.1506182f, 0.39038336f, 0.47727865f, 0.44018233f, 0.42101204f, 0.53943527f, 0.99320936f, 0.79050577f, - 0.77973497f, 0.7001237f, 0.88709056f, 0.4769255f, 0.5397561f, 0.60289854f, 0.06393474f, 0.09722155f, - 0.5613007f, 0.30437487f, 0.49082512f, 0.3852706f, 0.5778314f, 0.8253078f, 0.33417904f, 0.9004303f, - 0.8947809f, 0.11625093f, 0.11388689f, 0.09546256f, 0.22598988f, 0.30536187f, 0.46236527f, 0.3784039f, - 0.24737573f, 0.3411532f, 0.31912774f, 0.9905191f, 0.31468558f, 0.14199954f, 0.7078488f, 0.47111923f, - 0.882782f, 0.8124163f, 0.9593644f, 0.13382024f, 0.8214317f, 0.9196194f, 0.25308424f, 0.95958996f}; + 0.6255686f, 0.7868948f, 0.37150103f, 0.014990091f, 0.12253064f, 0.17739785f, 0.2880798f, 0.25019473f, + 0.1781866f, 0.15737581f, 0.2366305f, 0.75366426f, 0.39971018f, 0.93269855f, 0.7777848f, 0.35543054f, + 0.09472537f, 0.0029274821f, 0.78196156f, 0.24055547f, 0.13405013f, 0.2729941f, 0.49309975f, 0.379932f, + 0.994184f, 0.5411925f, 0.004210472f, 0.080795944f, 0.69859487f, 0.19460368f, 0.8686458f, 0.52536315f, + 0.71121234f, 0.95854944f, 0.6817853f, 0.84227383f, 0.75649333f, 0.8497335f, 0.9576164f, 0.36471486f, + 0.6911175f, 0.549694f, 0.7617172f, 0.82247066f, 0.053544044f, 0.25394785f, 0.36938924f, 0.8000995f, + 0.65789884f, 0.919321f, 0.89608955f, 0.029270172f, 0.9348151f, 0.15788019f, 0.19988996f, 0.17417055f, + 0.7006223f, 0.4366895f, 0.61926836f, 0.040263534f, 0.7878332f, 0.59613144f, 0.8557286f, 0.21456867f, + 0.065598905f, 0.6989418f, 0.31273842f, 0.0647831f, 0.7991694f, 0.22429371f, 0.5039311f, 0.009367704f, + 0.20085096f, 0.5693159f, 0.24570602f, 0.22299266f, 0.34460813f, 0.6356306f, 0.74428976f, 0.750327f, + 0.63625044f, 0.043019474f, 0.6682699f, 0.7801003f, 0.57832605f, 0.86499554f, 0.73779976f, 0.7819258f, + 0.28080195f, 0.3018247f, 0.981851f, 0.41664255f, 0.11966437f, 0.6922361f, 0.9410264f, 0.3208093f, + 0.45933473f, 0.32138962f, 0.6778976f, 0.7697645f, 0.66478735f, 0.65776104f, 0.15482187f, 0.63283706f, + 0.42452294f, 0.63012594f, 0.273876f, 0.16297674f, 0.5731115f, 0.7744376f, 0.21586305f, 0.80205464f, + 0.7284089f, 0.35509557f, 0.08370459f, 
0.91119635f, 0.97456336f, 0.661535f, 0.98558843f, 0.031699598f, + 0.40856004f, 0.6885702f, 0.8378734f, 0.98845494f, 0.7422309f, 0.38662314f, 0.2530955f, 0.47626138f, + 0.061956525f, 0.644235f, 0.735041f, 0.83424246f, 0.2638232f, 0.6273294f, 0.7184546f, 0.0766899f, + 0.32245284f, 0.60083544f, 0.906533f, 0.086719275f, 0.31617337f, 0.24053413f, 0.048938513f, 0.8919203f, + 0.22487706f, 0.011788309f, 0.9583977f, 0.69286025f, 0.9183893f, 0.542809f, 0.70638496f, 0.5914087f, + 0.29438752f, 0.67338234f, 0.73173434f, 0.63433063f, 0.4570613f, 0.70197225f, 0.60771453f, 0.7654821f, + 0.13812399f, 0.14224577f, 0.67939556f, 0.9691554f, 0.88740516f, 0.3961745f, 0.57973206f, 0.49563587f, + 0.37618434f, 0.026723444f, 0.97712487f, 0.1978364f, 0.96686924f, 0.58496886f, 0.13263822f, 0.7866956f, + 0.74798125f, 0.0469383f, 0.30301625f, 0.9748982f, 0.65107566f, 0.32560885f, 0.8141865f, 0.3695873f, + 0.3067485f, 0.5446483f, 0.29709607f, 0.5181831f, 0.29501313f, 0.33996922f, 0.11060041f, 0.02524674f, + 0.1647259f, 0.34876132f, 0.031807184f, 0.60586995f, 0.5313419f, 0.79801136f, 0.81332296f, 0.41627264f, + 0.9496114f, 0.466555f, 0.41363865f, 0.9874878f, 0.14229488f, 0.11154431f, 0.091483414f, 0.1414501f, + 0.45834088f, 0.3178513f, 0.68110096f, 0.13568163f, 0.07941705f, 0.53083426f, 0.96346164f, 0.5235164f, + 0.76482266f, 0.21967298f, 0.6893093f, 0.34609234f, 0.22017813f, 0.34257007f, 0.70869184f, 0.3112445f, + 0.6078779f, 0.5715967f, 0.25227106f, 0.94672066f, 0.44809794f, 0.8252871f, 0.88438797f, 0.86481494f, + 0.95148784f, 0.112026334f, 0.4173867f, 0.34240413f, 0.36137718f, 0.28898042f, 0.19898665f, 0.9130488f, + 0.22580266f, 0.40754277f, 0.75443816f, 0.26275212f, 0.9795632f, 0.4115007f, 0.37215167f, 0.6558706f, + 0.5015968f, 0.9426372f, 0.4018826f, 0.8016564f, 0.26275063f, 0.33729053f, 0.29362458f, 0.5511502f, + 0.12605143f, 0.5844644f, 0.30534166f, 0.009582937f, 0.41596514f, 0.7294f, 0.29118127f, 0.14877802f, + 0.77849144f, 0.37462044f, 0.09746289f, 0.612149f, 0.4321279f, 0.13902885f, 0.5097318f, 0.23706353f, + 0.5031309f, 0.35566723f, 0.813157f, 0.43805653f, 0.08011025f, 0.27288425f, 0.36515707f, 0.65959847f, + 0.42534047f, 0.06618434f, 0.89204556f, 0.25902748f, 0.2919191f, 0.77043164f, 0.7297412f, 0.3313396f, + 0.11166459f, 0.6964502f, 0.90752834f, 0.15901726f, 0.900068f, 0.2406702f, 0.6299379f, 0.7841949f, + 0.7123557f, 0.50165176f, 0.50806034f, 0.9854489f, 0.3689273f, 0.19051671f, 0.32111454f, 0.18070674f, + 0.39045036f, 0.6977819f, 0.9975799f, 0.5931799f, 0.24834275f, 0.6194577f, 0.095391035f, 0.7776301f, + 0.20649683f, 0.9747401f, 0.6052985f, 0.4263978f, 0.078885734f, 0.7982801f, 0.7177174f, 0.05027789f, + 0.36251247f, 0.63427305f, 0.64812917f, 0.7067924f, 0.44507074f, 0.23906898f, 0.19735986f, 0.03428924f, + 0.57597995f, 0.7426891f, 0.2980855f, 0.19379246f, 0.10265827f, 0.86077714f, 0.33929902f, 0.53255826f, + 0.9328308f, 0.30511153f, 0.32955307f, 0.39670604f, 0.5471632f, 0.26892018f, 0.5072957f, 0.3091979f, + 0.19757104f, 0.23322952f, 0.26604044f, 0.26614368f, 0.79264474f, 0.8869355f, 0.49160433f, 0.8244896f, + 0.65486836f, 0.92657536f, 0.753946f, 0.45817143f, 0.46995157f, 0.33152503f, 0.56953406f, 0.07021719f, + 0.749946f, 0.50672436f, 0.5824448f, 0.9922104f, 0.9277247f, 0.86002564f, 0.064810574f, 0.9553747f, + 0.41281444f, 0.42781502f, 0.92897725f, 0.7250554f, 0.029657006f, 0.3121612f, 0.77614623f, 0.18359429f, + 0.2813105f, 0.44517577f, 0.68485546f, 0.5000241f, 0.9771502f, 0.81278664f, 0.3692627f, 0.7917771f, + 0.24083132f, 0.19877058f, 0.28815013f, 0.30789846f, 0.24441916f, 0.7504142f, 0.4646262f, 
0.5149533f, + 0.6508104f, 0.5387549f, 0.1506182f, 0.77973497f, 0.5613007f, 0.8947809f, 0.24737573f, 0.882782f, + 0.005495131f, 0.90993816f, 0.80128354f, 0.8917693f, 0.091435075f, 0.6704894f, 0.95763975f, 0.27762014f, + 0.5927816f, 0.74048257f, 0.39038336f, 0.7001237f, 0.30437487f, 0.11625093f, 0.3411532f, 0.8124163f, + 0.6896569f, 0.71350795f, 0.6001208f, 0.36836517f, 0.5188247f, 0.01892668f, 0.15339965f, 0.25419557f, + 0.2064318f, 0.88833815f, 0.47727865f, 0.88709056f, 0.49082512f, 0.11388689f, 0.31912774f, 0.9593644f, + 0.78015697f, 0.8311006f, 0.63250524f, 0.6571592f, 0.20667112f, 0.9809466f, 0.14625502f, 0.04218155f, + 0.57540226f, 0.9262546f, 0.44018233f, 0.4769255f, 0.3852706f, 0.09546256f, 0.9905191f, 0.13382024f, + 0.07074398f, 0.16185725f, 0.4233225f, 0.3151368f, 0.9110969f, 0.41447622f, 0.58130866f, 0.7651092f, + 0.9817701f, 0.11023259f, 0.42101204f, 0.5397561f, 0.5778314f, 0.22598988f, 0.31468558f, 0.8214317f, + 0.67929304f, 0.79097277f, 0.7053677f, 0.8750746f, 0.019512117f, 0.032795787f, 0.43307304f, 0.59631824f, + 0.84294224f, 0.93783194f, 0.53943527f, 0.60289854f, 0.8253078f, 0.30536187f, 0.14199954f, 0.9196194f, + 0.9227386f, 0.15846318f, 0.29161406f, 0.7992451f, 0.72343415f, 0.9935814f, 0.6151709f, 0.077278376f, + 0.11056489f, 0.16041255f, 0.99320936f, 0.06393474f, 0.33417904f, 0.46236527f, 0.7078488f, 0.25308424f, + 0.5302883f, 0.99474716f, 0.028710365f, 0.6765068f, 0.998457f, 0.29653466f, 0.08064735f, 0.89677596f, + 0.9564106f, 0.53748304f, 0.79050577f, 0.09722155f, 0.9004303f, 0.3784039f, 0.47111923f, 0.95958996f}; const std::vector fc1_experts_bias = { 0.8748215f, 0.5054756f, 0.74107623f, 0.32518923f, 0.0639081f, 0.62639004f, 0.64906263f, 0.17322052f, 0.7424998f, 0.07288867f, 0.93031204f, 0.9841952f, 0.6361292f, 0.18628561f, 0.7433356f, 0.5852079f, @@ -435,19 +480,8 @@ TEST(MoETest, MoETest_Relu) { 0.012911659f, 0.045757107f, 0.27884653f, 0.3585817f, 0.116771236f, 0.25755364f, 0.23161705f, 0.2906256f, 4.8571277f, 5.649453f, 5.485141f, 5.306299f, 4.767025f, 6.9010167f, 5.3520975f, 6.711155f}; - RunMoETest(input, - router_probs, - fc1_experts_weights, - fc2_experts_weights, - {}, - fc1_experts_bias, - fc2_experts_bias, - output, - num_rows, - num_experts, - hidden_size, - inter_size, - "relu"); + RunMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, {}, fc1_experts_bias, fc2_experts_bias, + output, num_rows, num_experts, hidden_size, inter_size, "relu"); } TEST(MoETest, MoETest_Mixtral) { @@ -456,10 +490,10 @@ TEST(MoETest, MoETest_Mixtral) { int hidden_size = 4; int inter_size = 8; - const std::vector input = { - 0.9212995f, 0.5282444f, -0.008228387f, -1.449332f, -0.6051824f, -0.17924511f, 0.1995587f, -1.2461947f, - 0.86708033f, 0.19191018f, 1.1600108f, -0.008815222f, 0.8504777f, -0.84964496f, -1.4019964f, 0.17225051f, - 0.35569248f, 1.2056456f, 1.3690308f, -0.69495815f, 1.4324434f, 0.22761835f, -1.1286871f, 1.124213f}; + const std::vector input = {0.9212995f, 0.5282444f, -0.008228387f, -1.449332f, -0.6051824f, -0.17924511f, + 0.1995587f, -1.2461947f, 0.86708033f, 0.19191018f, 1.1600108f, -0.008815222f, + 0.8504777f, -0.84964496f, -1.4019964f, 0.17225051f, 0.35569248f, 1.2056456f, + 1.3690308f, -0.69495815f, 1.4324434f, 0.22761835f, -1.1286871f, 1.124213f}; const std::vector router_probs = { -0.09331456f, -0.47121337f, 0.07311103f, 0.47643483f, 0.21135253f, -0.72226393f, -0.048502743f, 0.39447474f, -0.9014899f, -0.36629856f, -0.23088816f, -0.099606544f, -0.45191774f, -0.30394578f, 0.6266495f, 0.67937183f, @@ -468,125 +502,771 @@ TEST(MoETest, 
MoETest_Mixtral) { 0.15675665f, -0.4546509f, 0.24447554f, 0.5921611f, -0.18192923f, -0.66116416f, -0.40265432f, 0.33475468f, 1.2906091f, 0.4709078f, 0.16256471f, 0.19308007f, 0.97568524f, 0.25876164f, -0.7964541f, -1.0319631f}; const std::vector fc1_experts_weights = { - 0.3860137f, 0.077925384f, 0.13434184f, 0.28902978f, 0.25391752f, -0.38351142f, 0.15813059f, 0.031481862f, - 0.083209574f, 0.4039817f, -0.13558972f, -0.21858627f, -0.30475253f, 0.41026944f, -0.008697987f, -0.3412701f, - -0.16235226f, 0.054659843f, 0.21042877f, 0.28863233f, -0.49495423f, 0.14401567f, 0.39130414f, 0.154176f, - 0.30897498f, -0.15768659f, 0.44641107f, 0.089463115f, -0.19318026f, 0.20710677f, -0.3552568f, -0.17219114f, - 0.41923493f, -0.4233985f, -0.41503525f, 0.19466156f, -0.08633667f, 0.45547962f, -0.054792404f, 0.26722562f, - -0.09923202f, 0.3460176f, -0.49708033f, -0.41033173f, 0.10443485f, -0.39646107f, -0.37424505f, 0.1757198f, - 0.43019837f, -0.13757241f, 0.14305532f, 0.37121457f, 0.2581259f, 0.12583363f, 0.45542932f, 0.16247797f, - 0.15579104f, -0.19166303f, -0.109221935f, -0.36702687f, 0.40365517f, -0.21506298f, -0.36697525f, -0.2703231f, - -0.49740213f, -0.3486371f, 0.24005288f, -0.0048963428f, 0.20468098f, -0.09111178f, -0.1485982f, -0.088219464f, - 0.33463532f, -0.49346995f, 0.42075223f, -0.38025302f, -0.245484f, -0.35191745f, 0.3086716f, -0.2423737f, - 0.37881732f, -0.40608948f, 0.26193494f, -0.4283861f, -0.10062629f, -0.32670784f, -0.16040438f, -0.15297079f, - 0.1822241f, 0.37285012f, 0.12654608f, -0.46767431f, -0.28775263f, 0.16585541f, -0.36678362f, -0.4759978f, - -0.34751755f, -0.3163945f, -0.3858195f, -0.38030273f, -0.06156373f, -0.04352224f, -0.4041785f, -0.335764f, - -0.10303855f, -0.4009425f, -0.1236487f, -0.40111196f, 0.23985302f, -0.118291676f, -0.26773083f, 0.121197104f, - 0.3702919f, -0.34168184f, 0.33743858f, 0.24873763f, -0.23140603f, -0.25351608f, 0.48291886f, 0.13780516f, - 0.25632292f, -0.49343884f, 0.08369112f, -0.37192065f, -0.05451995f, -0.44571918f, -0.24150735f, 0.27395487f, - -0.20423341f, -0.024149835f, 0.40208143f, -0.18211937f, -0.19767642f, -0.19397742f, -0.1510992f, 0.48074025f, - 0.18377024f, -0.18288034f, 0.08111167f, 0.12729281f, 0.27861303f, 0.0076527f, 0.36356348f, -0.24359548f, - -0.33313757f, -0.374829f, -0.08705664f, 0.23576546f, -0.39819986f, -0.09880793f, -0.012998581f, -0.36475456f, - -0.32685202f, 0.29657948f, -0.4631365f, -0.06320876f, 0.31600899f, 0.060619473f, 0.39029974f, 0.401151f, - 0.15562236f, 0.43565983f, -0.058149397f, 0.36150748f, 0.10750586f, -0.063970566f, -0.47026545f, -0.3035437f, - -0.38143605f, -0.4734699f, 0.31273925f, -0.43410504f, 0.07299572f, 0.47506f, 0.021913886f, -0.036100805f, - -0.31637233f, 0.37718338f, -0.046213806f, 0.19239199f, 0.13676548f, 0.33592474f, -0.34048676f, -0.11097133f, - -0.41569126f, -0.01680845f, 0.31357706f, 0.0943895f, -0.24053341f, -0.018784225f, 0.40659577f, 0.08897692f, - 0.3793823f, -0.3271106f, 0.067666054f, -0.12331611f, -0.010209799f, -0.48908865f, 0.19195485f, -0.45211792f, - 0.48282713f, 0.4363466f, -0.40184838f, -0.025082052f, -0.31057972f, 0.14850605f, 0.39756012f, -0.25782883f, - 0.3181312f, 0.17685872f, -0.16694272f, -0.41516554f, -0.062004805f, -0.33060408f, -0.13665432f, -0.43781847f, - -0.298562f, 0.013283849f, 0.48130906f, -0.27970356f, 0.20347959f, -0.24402553f, -0.20528454f, -0.114435256f, - 0.12556863f, -0.4344011f, 0.2868948f, 0.19894183f, -0.12849897f, -0.18726158f, -0.4850099f, -0.4352169f, - -0.40527463f, 0.13625044f, -0.49707252f, -0.45698053f, 0.28196156f, 0.16826987f, -0.25944453f, 
0.2801003f, - 0.21121234f, -0.04066527f, 0.45854944f, -0.17861038f, 0.18178529f, 0.17789757f, 0.34227383f, 0.26976448f, - 0.15789884f, 0.22840887f, 0.419321f, -0.14490443f, 0.39608955f, -0.4162954f, -0.47072983f, 0.41119635f}; + 0.3860137f, 0.083209574f, -0.16235226f, 0.30897498f, 0.077925384f, 0.4039817f, 0.054659843f, + -0.15768659f, 0.13434184f, -0.13558972f, 0.21042877f, 0.44641107f, 0.28902978f, -0.21858627f, + 0.28863233f, 0.089463115f, 0.25391752f, -0.30475253f, -0.49495423f, -0.19318026f, -0.38351142f, + 0.41026944f, 0.14401567f, 0.20710677f, 0.15813059f, -0.008697987f, 0.39130414f, -0.3552568f, + 0.031481862f, -0.3412701f, 0.154176f, -0.17219114f, 0.41923493f, -0.09923202f, 0.43019837f, + 0.15579104f, -0.4233985f, 0.3460176f, -0.13757241f, -0.19166303f, -0.41503525f, -0.49708033f, + 0.14305532f, -0.109221935f, 0.19466156f, -0.41033173f, 0.37121457f, -0.36702687f, -0.08633667f, + 0.10443485f, 0.2581259f, 0.40365517f, 0.45547962f, -0.39646107f, 0.12583363f, -0.21506298f, + -0.054792404f, -0.37424505f, 0.45542932f, -0.36697525f, 0.26722562f, 0.1757198f, 0.16247797f, + -0.2703231f, -0.49740213f, 0.33463532f, 0.37881732f, 0.1822241f, -0.3486371f, -0.49346995f, + -0.40608948f, 0.37285012f, 0.24005288f, 0.42075223f, 0.26193494f, 0.12654608f, -0.0048963428f, + -0.38025302f, -0.4283861f, -0.46767431f, 0.20468098f, -0.245484f, -0.10062629f, -0.28775263f, + -0.09111178f, -0.35191745f, -0.32670784f, 0.16585541f, -0.1485982f, 0.3086716f, -0.16040438f, + -0.36678362f, -0.088219464f, -0.2423737f, -0.15297079f, -0.4759978f, -0.34751755f, -0.10303855f, + 0.3702919f, 0.25632292f, -0.3163945f, -0.4009425f, -0.34168184f, -0.49343884f, -0.3858195f, + -0.1236487f, 0.33743858f, 0.08369112f, -0.38030273f, -0.40111196f, 0.24873763f, -0.37192065f, + -0.06156373f, 0.23985302f, -0.23140603f, -0.05451995f, -0.04352224f, -0.118291676f, -0.25351608f, + -0.44571918f, -0.4041785f, -0.26773083f, 0.48291886f, -0.24150735f, -0.335764f, 0.121197104f, + 0.13780516f, 0.27395487f, -0.20423341f, 0.18377024f, -0.33313757f, -0.32685202f, -0.024149835f, + -0.18288034f, -0.374829f, 0.29657948f, 0.40208143f, 0.08111167f, -0.08705664f, -0.4631365f, + -0.18211937f, 0.12729281f, 0.23576546f, -0.06320876f, -0.19767642f, 0.27861303f, -0.39819986f, + 0.31600899f, -0.19397742f, 0.0076527f, -0.09880793f, 0.060619473f, -0.1510992f, 0.36356348f, + -0.012998581f, 0.39029974f, 0.48074025f, -0.24359548f, -0.36475456f, 0.401151f, 0.15562236f, + -0.38143605f, -0.31637233f, -0.41569126f, 0.43565983f, -0.4734699f, 0.37718338f, -0.01680845f, + -0.058149397f, 0.31273925f, -0.046213806f, 0.31357706f, 0.36150748f, -0.43410504f, 0.19239199f, + 0.0943895f, 0.10750586f, 0.07299572f, 0.13676548f, -0.24053341f, -0.063970566f, 0.47506f, + 0.33592474f, -0.018784225f, -0.47026545f, 0.021913886f, -0.34048676f, 0.40659577f, -0.3035437f, + -0.036100805f, -0.11097133f, 0.08897692f, 0.3793823f, 0.48282713f, 0.3181312f, -0.298562f, + -0.3271106f, 0.4363466f, 0.17685872f, 0.013283849f, 0.067666054f, -0.40184838f, -0.16694272f, + 0.48130906f, -0.12331611f, -0.025082052f, -0.41516554f, -0.27970356f, -0.010209799f, -0.31057972f, + -0.062004805f, 0.20347959f, -0.48908865f, 0.14850605f, -0.33060408f, -0.24402553f, 0.19195485f, + 0.39756012f, -0.13665432f, -0.20528454f, -0.45211792f, -0.25782883f, -0.43781847f, -0.114435256f, + 0.12556863f, -0.40527463f, 0.21121234f, 0.15789884f, -0.4344011f, 0.13625044f, -0.04066527f, + 0.22840887f, 0.2868948f, -0.49707252f, 0.45854944f, 0.419321f, 0.19894183f, -0.45698053f, + -0.17861038f, -0.14490443f, -0.12849897f, 0.28196156f, 
0.18178529f, 0.39608955f, -0.18726158f, + 0.16826987f, 0.17789757f, -0.4162954f, -0.4850099f, -0.25944453f, 0.34227383f, -0.47072983f, + -0.4352169f, 0.2801003f, 0.26976448f, 0.41119635f}; const std::vector fc2_experts_weights = { - 0.10833451f, 0.34020698f, -0.18258394f, -0.17842063f, -0.07365984f, -0.29177922f, -0.24102151f, 0.1077901f, - 0.2932343f, -0.35068116f, 0.1875877f, 0.07474385f, -0.20955177f, -0.27660736f, -0.14290786f, -0.09014153f, - -0.21085852f, -0.2378315f, 0.21457997f, 0.21074237f, -0.21087126f, 0.14320332f, -0.08389844f, 0.24034885f, - 0.31800103f, 0.12659892f, 0.20224877f, -0.2563875f, 0.11782206f, 0.29377612f, -0.27469966f, -0.18875091f, - 0.32136288f, 0.0788243f, -0.26413083f, 0.18453442f, 0.0776935f, -0.19561274f, 0.12608862f, 0.18579696f, - 0.045481127f, -0.17894714f, 0.27366453f, 0.13220324f, -0.3115706f, -0.016884197f, -0.3328494f, -0.062126897f, - 0.14841764f, 0.19741052f, 0.08211302f, -0.09362138f, -0.053040292f, -0.090344846f, 0.18264277f, 0.037823465f, - -0.16197139f, -0.20172869f, 0.064109616f, -0.062456656f, 0.30368346f, -0.12107184f, -0.12590908f, -0.10535928f, - 0.1978099f, 0.13119277f, 0.21948591f, -0.080250844f, -0.24614547f, 0.33202717f, 0.2645375f, -0.21193951f, - 0.17770219f, -0.04986229f, 0.33435768f, -0.0309231f, 0.16043694f, -0.0027341924f, -0.08339601f, -0.17402375f, - 0.2525901f, -0.0813988f, -0.2904943f, -0.14452116f, -0.27119386f, -0.2952116f, 0.0794895f, -0.11223866f, - 0.25427446f, 0.16967128f, 0.19531254f, -0.33598322f, -0.16714293f, -0.35097876f, -0.35189477f, 0.2900932f, - 0.26874313f, -0.1322388f, -0.330179f, 0.064027935f, 0.19688474f, -0.20129368f, 0.006225848f, 0.19252343f, - -0.35054854f, -0.31874785f, 0.32238203f, 0.29287276f, 0.03135616f, 0.015792634f, 0.20397249f, -0.3245995f, - 0.21416605f, 0.15667121f, -0.2058509f, 0.23639117f, -0.032677338f, 0.07826358f, -0.04589425f, -0.24935842f, - -0.20834164f, 0.069915086f, -0.26063374f, 0.13239416f, 0.33705652f, -0.26813045f, -0.17056243f, 0.29919288f, - 0.27704936f, -0.096224755f, 0.13250813f, 0.26709175f, -0.26995474f, 0.3261805f, -0.18062393f, -0.04732303f, - -0.02733084f, 0.050550338f, -0.2937818f, -0.19453493f, -0.34864828f, -0.20862648f, -0.19311349f, 0.17665526f, - -0.2894185f, -0.020016002f, 0.3409702f, -0.18320526f, 0.068286195f, 0.08490415f, 0.30223787f, -0.2386011f, - 0.09405743f, 0.123811804f, 0.31660154f, -0.11290163f, 0.07494662f, -0.24999082f, 0.2075398f, 0.07419645f, - 0.3327035f, -0.09647329f, 0.24138254f, -0.32546985f, 0.033594366f, 0.16555631f, 0.33516192f, -0.32619375f, - 0.20476541f, -0.07724f, 0.018923176f, -0.21126744f, 0.2744358f, -0.23979841f, -0.30413106f, -0.3485449f, - 0.2854276f, 0.14391156f, -0.24802732f, -0.21701548f, -0.122100174f, 0.054206114f, -0.21961808f, 0.13481297f, - -0.07907457f, 0.15763119f, -0.31156835f, 0.29488218f, 0.17039073f, 0.35125035f, -0.17721775f, -0.10516899f, - 0.072144486f, -0.038529005f, -0.058253434f, 0.13062657f, -0.3312356f, -0.15963489f, -0.20129326f, 0.014987925f, - 0.30869225f, 0.283981f, -0.057181682f, 0.15174268f, 0.22181617f, -0.19763571f, 0.28675067f, 0.0003976555f, - -0.34610963f, 0.2931936f, -0.26233214f, 0.19563977f, -0.16886877f, 0.022812065f, 0.080249704f, -0.2798801f, - 0.11531327f, 0.07107194f, -0.34746924f, -0.051920194f, -0.07264093f, 0.27581826f, 0.18536879f, 0.15684144f, - -0.26691115f, -0.22811417f, -0.1498502f, -0.176639f, -0.25876564f, -0.16051741f, -0.0048792143f, -0.08490091f, - 0.18136817f, 0.24729891f, 0.32358363f, -0.09566104f, 0.3074607f, -0.24191524f, -0.21220984f, -0.23039621f, - 0.21154472f, -0.19495378f, 
0.002779711f, -0.34692943f, 0.055384878f, 0.25809082f, 0.16814983f, 0.19935164f, - 0.11652225f, 0.1115539f, -0.24407779f, 0.09392998f, 0.33556697f, 0.11422251f, 0.34336287f, -0.33113837f}; + 0.10833451f, -0.07365984f, 0.2932343f, -0.20955177f, -0.21085852f, -0.21087126f, 0.31800103f, + 0.11782206f, 0.34020698f, -0.29177922f, -0.35068116f, -0.27660736f, -0.2378315f, 0.14320332f, + 0.12659892f, 0.29377612f, -0.18258394f, -0.24102151f, 0.1875877f, -0.14290786f, 0.21457997f, + -0.08389844f, 0.20224877f, -0.27469966f, -0.17842063f, 0.1077901f, 0.07474385f, -0.09014153f, + 0.21074237f, 0.24034885f, -0.2563875f, -0.18875091f, 0.32136288f, 0.0776935f, 0.045481127f, + -0.3115706f, 0.14841764f, -0.053040292f, -0.16197139f, 0.30368346f, 0.0788243f, -0.19561274f, + -0.17894714f, -0.016884197f, 0.19741052f, -0.090344846f, -0.20172869f, -0.12107184f, -0.26413083f, + 0.12608862f, 0.27366453f, -0.3328494f, 0.08211302f, 0.18264277f, 0.064109616f, -0.12590908f, + 0.18453442f, 0.18579696f, 0.13220324f, -0.062126897f, -0.09362138f, 0.037823465f, -0.062456656f, + -0.10535928f, 0.1978099f, -0.24614547f, 0.17770219f, 0.16043694f, 0.2525901f, -0.27119386f, + 0.25427446f, -0.16714293f, 0.13119277f, 0.33202717f, -0.04986229f, -0.0027341924f, -0.0813988f, + -0.2952116f, 0.16967128f, -0.35097876f, 0.21948591f, 0.2645375f, 0.33435768f, -0.08339601f, + -0.2904943f, 0.0794895f, 0.19531254f, -0.35189477f, -0.080250844f, -0.21193951f, -0.0309231f, + -0.17402375f, -0.14452116f, -0.11223866f, -0.33598322f, 0.2900932f, 0.26874313f, 0.19688474f, + -0.35054854f, 0.03135616f, 0.21416605f, -0.032677338f, -0.20834164f, 0.33705652f, -0.1322388f, + -0.20129368f, -0.31874785f, 0.015792634f, 0.15667121f, 0.07826358f, 0.069915086f, -0.26813045f, + -0.330179f, 0.006225848f, 0.32238203f, 0.20397249f, -0.2058509f, -0.04589425f, -0.26063374f, + -0.17056243f, 0.064027935f, 0.19252343f, 0.29287276f, -0.3245995f, 0.23639117f, -0.24935842f, + 0.13239416f, 0.29919288f, 0.27704936f, -0.26995474f, -0.02733084f, -0.34864828f, -0.2894185f, + 0.068286195f, 0.09405743f, 0.07494662f, -0.096224755f, 0.3261805f, 0.050550338f, -0.20862648f, + -0.020016002f, 0.08490415f, 0.123811804f, -0.24999082f, 0.13250813f, -0.18062393f, -0.2937818f, + -0.19311349f, 0.3409702f, 0.30223787f, 0.31660154f, 0.2075398f, 0.26709175f, -0.04732303f, + -0.19453493f, 0.17665526f, -0.18320526f, -0.2386011f, -0.11290163f, 0.07419645f, 0.3327035f, + 0.033594366f, 0.20476541f, 0.2744358f, 0.2854276f, -0.122100174f, -0.07907457f, 0.17039073f, + -0.09647329f, 0.16555631f, -0.07724f, -0.23979841f, 0.14391156f, 0.054206114f, 0.15763119f, + 0.35125035f, 0.24138254f, 0.33516192f, 0.018923176f, -0.30413106f, -0.24802732f, -0.21961808f, + -0.31156835f, -0.17721775f, -0.32546985f, -0.32619375f, -0.21126744f, -0.3485449f, -0.21701548f, + 0.13481297f, 0.29488218f, -0.10516899f, 0.072144486f, -0.3312356f, 0.30869225f, 0.22181617f, + -0.34610963f, -0.16886877f, 0.11531327f, -0.07264093f, -0.038529005f, -0.15963489f, 0.283981f, + -0.19763571f, 0.2931936f, 0.022812065f, 0.07107194f, 0.27581826f, -0.058253434f, -0.20129326f, + -0.057181682f, 0.28675067f, -0.26233214f, 0.080249704f, -0.34746924f, 0.18536879f, 0.13062657f, + 0.014987925f, 0.15174268f, 0.0003976555f, 0.19563977f, -0.2798801f, -0.051920194f, 0.15684144f, + -0.26691115f, -0.25876564f, 0.18136817f, 0.3074607f, 0.21154472f, 0.055384878f, 0.11652225f, + 0.33556697f, -0.22811417f, -0.16051741f, 0.24729891f, -0.24191524f, -0.19495378f, 0.25809082f, + 0.1115539f, 0.11422251f, -0.1498502f, -0.0048792143f, 0.32358363f, -0.21220984f, 
0.002779711f, + 0.16814983f, -0.24407779f, 0.34336287f, -0.176639f, -0.08490091f, -0.09566104f, -0.23039621f, + -0.34692943f, 0.19935164f, 0.09392998f, -0.33113837f}; const std::vector fc3_experts_weights = { - 0.45783097f, -0.2863351f, 0.011728346f, -0.43760604f, 0.15407985f, 0.07818556f, 0.0013856292f, -0.34319758f, - -0.16871625f, 0.12490183f, -0.34154075f, -0.31836903f, -0.46634215f, -0.43996066f, -0.1860516f, -0.2917009f, - -0.1772582f, -0.06599659f, -0.42419833f, 0.49980444f, -0.3283869f, -0.21543652f, -0.034647882f, -0.17114872f, - -0.4837973f, -0.362943f, -0.27533132f, 0.09443748f, -0.16642791f, -0.2993343f, -0.33881485f, -0.39464045f, - 0.31960344f, 0.007296145f, -0.45412838f, -0.024868786f, -0.16298121f, -0.44197202f, 0.07232875f, -0.32362783f, - 0.42969978f, -0.029854119f, -0.18451887f, -0.30145288f, 0.16885209f, -0.30068123f, -0.12948537f, 0.36494362f, - -0.049498677f, 0.12020564f, 0.42106473f, -0.30590254f, 0.31881082f, -0.078908324f, 0.20685762f, -0.22735089f, - -0.11194843f, 0.14011681f, 0.19477749f, -0.44788343f, 0.23084867f, 0.48367476f, -0.19044077f, -0.100233376f, - 0.4191656f, -0.4515314f, -0.3214385f, 0.016065598f, -0.4069137f, -0.17348295f, -0.43329984f, 0.33521235f, - -0.07843453f, -0.4865722f, -0.039011598f, -0.10605621f, 0.4192536f, 0.04063064f, 0.1984514f, 0.49294376f, - -0.056941032f, 0.18582922f, -0.16650558f, -0.17215621f, -0.20009357f, 0.46615022f, 0.47462142f, -0.0766145f, - -0.20405996f, -0.27452308f, -0.16176039f, -0.23940295f, 0.13248974f, 0.23036134f, 0.13154167f, 0.10377723f, - 0.0070211887f, 0.29162645f, 0.34465307f, -0.4058748f, -0.13989884f, -0.12305027f, -0.2541607f, 0.4767149f, - 0.4549045f, -0.108933926f, 0.2452516f, 0.054080307f, 0.33768386f, -0.45279485f, 0.1557768f, 0.17416143f, - -0.42602575f, -0.102350116f, 0.16022503f, 0.14813942f, 0.03982985f, -0.47012872f, -0.14555538f, 0.35645115f, - -0.1909796f, -0.20839584f, -0.28098184f, -0.23085594f, 0.022559166f, -0.23900753f, -0.19561106f, -0.24205637f, - 0.2573983f, -0.2947166f, 0.4568925f, 0.11514187f, 0.18671238f, -0.121082425f, 0.3909887f, -0.10985571f, - -0.19420451f, -0.3255307f, 0.4863913f, 0.007830441f, 0.4648854f, -0.24156213f, 0.22956276f, -0.09216207f, - -0.29428315f, 0.26062596f, 0.14955276f, -0.036366224f, -0.12957954f, 0.08501935f, -0.36796576f, 0.041123867f, - 0.06744653f, -0.0839923f, 0.17207885f, 0.006872058f, -0.21135789f, 0.3732242f, -0.2683524f, -0.45898575f, - -0.14543939f, 0.30806476f, 0.08574325f, 0.027492225f, -0.38164973f, -0.040038824f, -0.26947904f, -0.09740937f, - 0.26697665f, -0.43565083f, 0.1359719f, 0.12271714f, 0.0149876475f, -0.44011843f, 0.26128954f, -0.42487514f, - -0.24668545f, 0.06113738f, -0.29119557f, 0.194273f, -0.24981815f, 0.3489496f, -0.47321397f, -0.31794417f, - -0.23641628f, 0.44169098f, -0.006898284f, 0.43446392f, -0.39553195f, 0.057907403f, -0.19339961f, -0.08160931f, - 0.4979084f, -0.11149913f, 0.35366338f, -0.16032219f, -0.48278677f, 0.08397317f, 0.4008311f, 0.30288273f, - 0.2546957f, -0.10675722f, 0.069722414f, 0.456497f, -0.19691509f, 0.49017924f, 0.41796166f, -0.2337895f, - -0.3635872f, -0.45445484f, -0.29122698f, -0.4339773f, 0.15762383f, 0.09782606f, -0.27986187f, -0.23860168f, - 0.38454843f, -0.07870716f, 0.15390605f, -0.15793777f, 0.48130733f, 0.288768f, 0.45969498f, -0.4193731f, - -0.3218134f, -0.29914904f, -0.3426242f, 0.06931591f, -0.2633695f, -0.25429398f, 0.25366426f, -0.27700734f, - 0.49418402f, -0.21919805f, 0.041192472f, -0.19817531f, -0.49578953f, 0.48185098f, -0.41920406f, -0.08335745f, - 0.19111753f, -0.07547706f, 0.049694f, 
0.13012594f, 0.2617172f, -0.22612399f, 0.32247066f, -0.33702326f,
- 0.20062232f, -0.09143996f, -0.063310504f, 0.1885702f, 0.11926836f, 0.3378734f, -0.45973647f, 0.48845494f};
- const std::vector<float> output = {
- 0.026516449f, 0.04061616f, 0.04403834f, -0.13644142f, 0.038774252f, 0.024002096f, -0.061423667f, 0.034824893f,
- -0.022858473f, 0.04693405f, -0.0120724365f, -0.028846134f, -0.0168579f, -0.07958221f, 0.048179876f, 0.053492386f,
- -0.026292695f, -0.009724421f, -0.026503641f, 0.031220898f, 0.04189077f, 0.11775493f, -0.037770163f, -0.0790936f};
+ 0.45783097f, -0.16871625f, -0.1772582f, -0.4837973f, -0.2863351f, 0.12490183f, -0.06599659f,
+ -0.362943f, 0.011728346f, -0.34154075f, -0.42419833f, -0.27533132f, -0.43760604f, -0.31836903f,
+ 0.49980444f, 0.09443748f, 0.15407985f, -0.46634215f, -0.3283869f, -0.16642791f, 0.07818556f,
+ -0.43996066f, -0.21543652f, -0.2993343f, 0.0013856292f, -0.1860516f, -0.034647882f, -0.33881485f,
+ -0.34319758f, -0.2917009f, -0.17114872f, -0.39464045f, 0.31960344f, 0.42969978f, -0.049498677f,
+ -0.11194843f, 0.007296145f, -0.029854119f, 0.12020564f, 0.14011681f, -0.45412838f, -0.18451887f,
+ 0.42106473f, 0.19477749f, -0.024868786f, -0.30145288f, -0.30590254f, -0.44788343f, -0.16298121f,
+ 0.16885209f, 0.31881082f, 0.23084867f, -0.44197202f, -0.30068123f, -0.078908324f, 0.48367476f,
+ 0.07232875f, -0.12948537f, 0.20685762f, -0.19044077f, -0.32362783f, 0.36494362f, -0.22735089f,
+ -0.100233376f, 0.4191656f, -0.07843453f, -0.056941032f, -0.20405996f, -0.4515314f, -0.4865722f,
+ 0.18582922f, -0.27452308f, -0.3214385f, -0.039011598f, -0.16650558f, -0.16176039f, 0.016065598f,
+ -0.10605621f, -0.17215621f, -0.23940295f, -0.4069137f, 0.4192536f, -0.20009357f, 0.13248974f,
+ -0.17348295f, 0.04063064f, 0.46615022f, 0.23036134f, -0.43329984f, 0.1984514f, 0.47462142f,
+ 0.13154167f, 0.33521235f, 0.49294376f, -0.0766145f, 0.10377723f, 0.0070211887f, 0.4549045f,
+ -0.42602575f, -0.1909796f, 0.29162645f, -0.108933926f, -0.102350116f, -0.20839584f, 0.34465307f,
+ 0.2452516f, 0.16022503f, -0.28098184f, -0.4058748f, 0.054080307f, 0.14813942f, -0.23085594f,
+ -0.13989884f, 0.33768386f, 0.03982985f, 0.022559166f, -0.12305027f, -0.45279485f, -0.47012872f,
+ -0.23900753f, -0.2541607f, 0.1557768f, -0.14555538f, -0.19561106f, 0.4767149f, 0.17416143f,
+ 0.35645115f, -0.24205637f, 0.2573983f, -0.19420451f, -0.29428315f, 0.06744653f, -0.2947166f,
+ -0.3255307f, 0.26062596f, -0.0839923f, 0.4568925f, 0.4863913f, 0.14955276f, 0.17207885f,
+ 0.11514187f, 0.007830441f, -0.036366224f, 0.006872058f, 0.18671238f, 0.4648854f, -0.12957954f,
+ -0.21135789f, -0.121082425f, -0.24156213f, 0.08501935f, 0.3732242f, 0.3909887f, 0.22956276f,
+ -0.36796576f, -0.2683524f, -0.10985571f, -0.09216207f, 0.041123867f, -0.45898575f, -0.14543939f,
+ 0.26697665f, -0.24668545f, -0.23641628f, 0.30806476f, -0.43565083f, 0.06113738f, 0.44169098f,
+ 0.08574325f, 0.1359719f, -0.29119557f, -0.006898284f, 0.027492225f, 0.12271714f, 0.194273f,
+ 0.43446392f, -0.38164973f, 0.0149876475f, -0.24981815f, -0.39553195f, -0.040038824f, -0.44011843f,
+ 0.3489496f, 0.057907403f, -0.26947904f, 0.26128954f, -0.47321397f, -0.19339961f, -0.09740937f,
+ -0.42487514f, -0.31794417f, -0.08160931f, 0.4979084f, 0.2546957f, -0.3635872f, 0.38454843f,
+ -0.11149913f, -0.10675722f, -0.45445484f, -0.07870716f, 0.35366338f, 0.069722414f, -0.29122698f,
+ 0.15390605f, -0.16032219f, 0.456497f, -0.4339773f, -0.15793777f, -0.48278677f, -0.19691509f,
+ 0.15762383f, 0.48130733f, 0.08397317f, 0.49017924f, 0.09782606f, 0.288768f, 0.4008311f,
+ 0.41796166f, -0.27986187f, 0.45969498f, 0.30288273f, -0.2337895f, -0.23860168f, -0.4193731f,
-0.27986187f, 0.45969498f, 0.30288273f, -0.2337895f, -0.23860168f, -0.4193731f, + -0.3218134f, 0.49418402f, 0.19111753f, 0.20062232f, -0.29914904f, -0.21919805f, -0.07547706f, + -0.09143996f, -0.3426242f, 0.041192472f, 0.049694f, -0.063310504f, 0.06931591f, -0.19817531f, + 0.13012594f, 0.1885702f, -0.2633695f, -0.49578953f, 0.2617172f, 0.11926836f, -0.25429398f, + 0.48185098f, -0.22612399f, 0.3378734f, 0.25366426f, -0.41920406f, 0.32247066f, -0.45973647f, + -0.27700734f, -0.08335745f, -0.33702326f, 0.48845494f}; + const std::vector<float> output = {0.026516449f, 0.04061616f, 0.04403834f, -0.13644142f, 0.038774252f, + 0.024002096f, -0.061423667f, 0.034824893f, -0.022858473f, 0.04693405f, + -0.0120724365f, -0.028846134f, -0.0168579f, -0.07958221f, 0.048179876f, + 0.053492386f, -0.026292695f, -0.009724421f, -0.026503641f, 0.031220898f, + 0.04189077f, 0.11775493f, -0.037770163f, -0.0790936f}; - RunMoETest(input, - router_probs, - fc1_experts_weights, - fc2_experts_weights, - fc3_experts_weights, - {}, - {}, - output, - num_rows, - num_experts, - hidden_size, - inter_size, - "silu", - 1, /*normalize_routing_weights*/ + RunMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, fc3_experts_weights, {}, {}, output, + num_rows, num_experts, hidden_size, inter_size, "silu", 1, /*normalize_routing_weights*/ 2 /*top_k*/); } +TEST(MoETest, QMoETest_Mixtral_Int4) { + int num_rows = 2; + int num_experts = 2; + int hidden_size = 64; + int inter_size = 64; + + const std::vector<float> input = { + -0.8477f, -0.0746f, 1.606f, -0.3242f, 0.4028f, 0.2384f, -0.0359f, -1.667f, -1.265f, -0.3035f, 0.5327f, + 1.109f, 1.111f, 0.533f, -0.5947f, -0.2009f, 0.4224f, -0.576f, 0.825f, 1.038f, -0.2722f, 0.0497f, + 1.963f, -1.075f, -0.8374f, 1.055f, 0.448f, -0.602f, -0.2874f, -1.311f, -0.0609f, -1.991f, -0.0732f, + -1.49f, 0.6636f, -0.4053f, -1.603f, -1.088f, 0.09534f, -0.6807f, -0.3958f, 1.205f, -0.4275f, 0.82f, + 1.029f, 0.2693f, 1.229f, 1.116f, 0.718f, -0.827f, 2.527f, -1.041f, 1.042f, -2.771f, -0.654f, + 0.7144f, 0.6255f, -0.00957f, -0.2313f, 0.4663f, 2.803f, 0.0655f, 1.232f, 1.557f, -1.238f, -1.337f, + 0.1522f, -0.2783f, 0.2252f, 2.252f, 0.557f, -0.6885f, 1.16f, -0.5244f, -1.424f, -0.02344f, -1.09f, + -0.749f, -1.118f, -2.6f, -1.308f, -0.742f, 0.3064f, 1.892f, 1.573f, -0.3843f, 0.6475f, 0.38f, + -1.423f, -2.04f, 0.005592f, -0.5977f, 1.063f, -1.626f, -0.04883f, -2.041f, -0.502f, -0.8906f, -0.3987f, + 0.387f, 0.4644f, -1.419f, -1.35f, -0.9634f, -0.871f, -0.53f, 0.495f, -0.6157f, 0.6523f, -1.036f, + -1.234f, 0.11566f, 0.2035f, -1.782f, 0.837f, -0.8955f, -1.392f, -0.4f, 0.6533f, -0.289f, -1.328f, + 0.528f, -1.269f, -0.581f, -0.4805f, -1.539f, -0.554f, 0.478f}; + const std::vector<float> router_probs = {-0.579f, -0.07007f, 0.0784f, 0.5327f}; + const std::vector<uint8_t> fc1_experts_weights = { + 31, 119, 6, 42, 175, 252, 107, 46, 177, 207, 30, 178, 230, 186, 37, 69, 175, 194, 74, 203, 73, 190, 129, + 112, 203, 106, 103, 156, 52, 121, 95, 101, 29, 149, 95, 107, 247, 189, 182, 92, 136, 49, 56, 227, 58, 71, + 26, 111, 192, 107, 253, 212, 206, 86, 171, 35, 130, 119, 32, 66, 17, 99, 14, 11, 188, 109, 242, 62, 124, + 127, 140, 70, 110, 81, 18, 90, 206, 12, 4, 240, 63, 17, 119, 191, 87, 245, 85, 170, 129, 213, 96, 249, + 91, 127, 68, 172, 34, 23, 102, 76, 222, 244, 156, 71, 35, 55, 34, 166, 194, 164, 32, 193, 94, 144, 60, + 200, 16, 255, 137, 29, 205, 211, 167, 39, 163, 178, 47, 244, 232, 70, 207, 190, 177, 105, 53, 88, 29, 246, + 150, 176, 166, 224, 69, 68, 245, 85, 155, 237, 114, 129, 3, 237, 88, 245, 165, 72, 194, 38, 125, 249, 110, + 235,
242, 116, 151, 65, 48, 196, 129, 79, 170, 63, 186, 95, 42, 173, 252, 235, 245, 207, 163, 38, 185, 69, + 163, 102, 131, 116, 124, 153, 135, 30, 123, 10, 210, 137, 115, 58, 209, 244, 239, 55, 16, 127, 190, 109, 168, + 152, 111, 29, 201, 31, 109, 56, 62, 1, 191, 50, 149, 179, 198, 241, 252, 100, 150, 66, 172, 200, 52, 4, + 132, 82, 141, 103, 161, 17, 185, 62, 255, 121, 158, 184, 31, 243, 93, 103, 243, 91, 99, 16, 143, 74, 27, + 19, 56, 100, 122, 183, 64, 131, 239, 183, 77, 143, 240, 194, 148, 42, 171, 112, 230, 204, 239, 224, 156, 245, + 50, 75, 174, 255, 86, 217, 246, 79, 17, 74, 87, 29, 218, 6, 1, 77, 251, 235, 251, 14, 177, 21, 26, + 219, 87, 203, 108, 241, 50, 69, 66, 104, 236, 111, 47, 51, 77, 8, 233, 220, 187, 162, 191, 15, 35, 196, + 159, 16, 155, 220, 123, 26, 147, 68, 84, 127, 167, 52, 183, 253, 246, 108, 59, 142, 65, 214, 173, 62, 165, + 204, 178, 95, 201, 211, 108, 132, 161, 220, 3, 191, 34, 251, 28, 245, 99, 160, 87, 43, 154, 244, 100, 76, + 164, 28, 159, 155, 110, 137, 90, 109, 217, 125, 145, 143, 215, 238, 127, 182, 165, 31, 181, 172, 6, 63, 149, + 133, 127, 247, 60, 91, 47, 87, 248, 214, 69, 225, 183, 202, 159, 210, 85, 2, 34, 33, 83, 91, 86, 243, + 84, 149, 146, 5, 104, 125, 178, 63, 182, 159, 166, 186, 95, 101, 250, 124, 119, 207, 199, 37, 65, 182, 78, + 218, 164, 193, 71, 19, 73, 88, 88, 25, 80, 60, 109, 161, 23, 78, 139, 248, 243, 122, 201, 53, 67, 68, + 140, 249, 116, 139, 39, 72, 197, 92, 54, 53, 209, 28, 226, 149, 237, 216, 239, 241, 223, 214, 52, 85, 240, + 237, 187, 106, 220, 186, 49, 87, 219, 235, 63, 213, 248, 176, 196, 135, 177, 79, 8, 66, 243, 18, 127, 201, + 16, 167, 252, 26, 95, 225, 154, 210, 202, 182, 227, 232, 249, 135, 47, 151, 254, 169, 34, 31, 159, 29, 233, + 228, 37, 213, 40, 245, 22, 47, 73, 213, 93, 242, 117, 75, 110, 248, 206, 74, 24, 139, 252, 174, 245, 101, + 161, 214, 58, 154, 202, 25, 147, 127, 100, 111, 217, 190, 167, 20, 16, 237, 167, 247, 247, 192, 54, 153, 63, + 187, 178, 64, 182, 209, 247, 218, 211, 45, 242, 132, 94, 137, 238, 184, 240, 250, 42, 193, 150, 139, 159, 242, + 149, 89, 22, 111, 9, 78, 194, 146, 249, 173, 185, 243, 147, 250, 228, 144, 123, 250, 49, 92, 231, 1, 152, + 179, 101, 178, 255, 94, 136, 6, 30, 77, 173, 137, 110, 56, 16, 90, 95, 115, 145, 113, 51, 172, 152, 242, + 119, 7, 186, 149, 168, 213, 228, 229, 133, 31, 40, 189, 19, 74, 88, 75, 134, 255, 17, 116, 208, 224, 242, + 252, 156, 153, 44, 165, 119, 23, 206, 175, 33, 213, 59, 243, 103, 244, 92, 130, 184, 162, 229, 49, 203, 157, + 208, 106, 156, 237, 218, 223, 235, 58, 203, 117, 228, 119, 127, 58, 169, 171, 166, 203, 180, 254, 149, 90, 37, + 117, 63, 98, 107, 130, 143, 179, 72, 168, 184, 61, 137, 185, 123, 57, 70, 90, 115, 78, 26, 36, 167, 237, + 145, 247, 220, 103, 129, 207, 183, 6, 56, 178, 46, 198, 245, 220, 42, 59, 134, 217, 186, 33, 50, 200, 145, + 84, 96, 139, 72, 115, 253, 221, 42, 177, 49, 22, 201, 183, 194, 46, 222, 94, 169, 233, 22, 252, 228, 48, + 91, 72, 201, 150, 203, 219, 247, 30, 117, 19, 131, 74, 235, 214, 223, 221, 244, 36, 170, 94, 84, 253, 52, + 186, 223, 127, 156, 248, 182, 74, 89, 223, 202, 255, 194, 151, 195, 115, 176, 180, 55, 194, 33, 14, 133, 39, + 250, 129, 142, 25, 49, 126, 47, 67, 215, 56, 116, 242, 117, 36, 98, 207, 78, 168, 150, 175, 109, 229, 54, + 45, 221, 205, 130, 52, 133, 208, 174, 234, 234, 188, 71, 250, 3, 43, 225, 57, 144, 225, 157, 202, 251, 194, + 242, 106, 188, 121, 239, 104, 206, 238, 116, 28, 253, 1, 62, 153, 193, 147, 24, 120, 70, 241, 148, 54, 227, + 159, 242, 208, 23, 14, 102, 41, 29, 254, 184, 52, 27, 45, 69, 137, 149, 23, 151, 
123, 99, 190, 107, 247, + 81, 164, 57, 163, 77, 213, 250, 203, 235, 134, 85, 83, 1, 188, 70, 100, 70, 56, 217, 34, 55, 103, 123, + 9, 191, 147, 132, 36, 21, 38, 89, 82, 170, 166, 167, 129, 238, 171, 227, 254, 188, 14, 202, 249, 54, 158, + 12, 146, 139, 203, 113, 153, 163, 180, 104, 220, 210, 108, 77, 200, 183, 135, 118, 241, 170, 143, 238, 120, 68, + 215, 67, 133, 182, 222, 155, 12, 254, 157, 250, 87, 190, 226, 141, 62, 250, 212, 43, 218, 109, 229, 79, 213, + 249, 190, 225, 107, 238, 15, 233, 72, 56, 193, 64, 86, 156, 215, 170, 70, 69, 28, 221, 97, 21, 73, 8, + 184, 53, 138, 80, 239, 59, 117, 23, 39, 104, 167, 202, 47, 231, 253, 174, 226, 114, 40, 140, 3, 200, 198, + 134, 85, 10, 104, 50, 1, 75, 243, 1, 248, 148, 143, 180, 213, 47, 198, 76, 78, 12, 15, 128, 23, 144, + 82, 50, 132, 219, 73, 81, 45, 169, 39, 178, 85, 130, 111, 17, 111, 54, 13, 47, 184, 139, 245, 23, 74, + 202, 138, 183, 30, 136, 134, 120, 105, 95, 253, 89, 252, 67, 169, 149, 251, 4, 57, 175, 30, 79, 15, 104, + 34, 65, 252, 101, 211, 6, 81, 72, 40, 18, 240, 163, 255, 53, 132, 173, 124, 251, 145, 69, 98, 46, 137, + 179, 11, 113, 247, 158, 232, 179, 215, 14, 47, 185, 237, 58, 121, 159, 227, 227, 244, 73, 214, 181, 149, 47, + 93, 20, 245, 237, 166, 192, 82, 133, 84, 33, 124, 12, 229, 59, 44, 30, 62, 136, 243, 143, 119, 101, 253, + 36, 159, 245, 87, 175, 250, 223, 86, 185, 146, 51, 39, 248, 128, 212, 189, 62, 190, 232, 134, 148, 88, 158, + 113, 157, 130, 201, 122, 132, 250, 250, 87, 200, 79, 47, 148, 30, 232, 74, 199, 188, 175, 234, 8, 95, 141, + 253, 49, 158, 20, 135, 153, 195, 212, 255, 104, 68, 78, 75, 155, 253, 163, 146, 58, 124, 34, 231, 4, 138, + 241, 19, 36, 25, 114, 230, 167, 147, 72, 69, 242, 130, 46, 228, 130, 210, 149, 74, 248, 79, 251, 31, 180, + 249, 233, 148, 243, 47, 170, 153, 176, 185, 207, 111, 191, 191, 13, 180, 247, 72, 91, 196, 244, 187, 245, 16, + 188, 239, 66, 253, 181, 45, 52, 245, 26, 63, 112, 42, 68, 29, 216, 166, 57, 87, 252, 27, 159, 177, 143, + 142, 34, 244, 38, 214, 144, 243, 72, 165, 69, 30, 241, 164, 126, 164, 228, 142, 184, 251, 172, 51, 49, 202, + 241, 76, 203, 169, 210, 46, 204, 49, 108, 138, 164, 116, 188, 23, 163, 251, 107, 156, 159, 69, 125, 163, 93, + 80, 31, 97, 200, 83, 212, 111, 248, 154, 82, 128, 187, 232, 254, 195, 174, 195, 69, 204, 39, 67, 34, 78, + 253, 107, 2, 219, 245, 222, 22, 200, 83, 39, 220, 206, 47, 76, 118, 63, 67, 142, 24, 115, 97, 56, 233, + 199, 230, 206, 186, 91, 18, 141, 148, 43, 94, 25, 96, 140, 233, 165, 202, 242, 71, 209, 235, 173, 105, 40, + 82, 63, 165, 218, 95, 48, 89, 103, 89, 24, 100, 195, 208, 251, 28, 240, 235, 157, 207, 191, 52, 249, 203, + 134, 40, 153, 219, 102, 29, 32, 196, 71, 149, 124, 189, 183, 14, 244, 90, 123, 152, 206, 232, 70, 138, 111, + 231, 6, 237, 243, 69, 101, 118, 41, 70, 171, 97, 133, 207, 234, 21, 24, 223, 131, 55, 18, 253, 128, 63, + 179, 85, 174, 62, 36, 217, 191, 90, 142, 215, 79, 247, 16, 53, 200, 241, 85, 141, 194, 125, 75, 30, 237, + 4, 255, 247, 80, 69, 127, 204, 238, 106, 255, 98, 24, 206, 110, 57, 103, 90, 175, 100, 253, 151, 206, 90, + 33, 243, 52, 13, 51, 212, 99, 182, 10, 105, 41, 111, 22, 182, 207, 205, 81, 138, 203, 205, 99, 204, 125, + 152, 55, 3, 130, 116, 116, 106, 178, 74, 234, 56, 51, 20, 223, 180, 49, 155, 79, 53, 194, 227, 118, 116, + 165, 100, 196, 255, 91, 74, 250, 19, 177, 79, 188, 141, 91, 149, 65, 28, 76, 53, 185, 123, 164, 127, 74, + 168, 73, 65, 202, 112, 41, 111, 233, 240, 200, 241, 155, 181, 78, 136, 222, 55, 161, 1, 252, 241, 35, 54, + 223, 122, 32, 31, 59, 46, 177, 168, 246, 226, 25, 71, 24, 221, 120, 27, 118, 
242, 216, 200, 252, 208, 174, + 165, 169, 177, 84, 249, 19, 236, 146, 17, 15, 29, 117, 31, 24, 67, 208, 105, 117, 100, 109, 19, 34, 159, + 82, 33, 194, 122, 211, 167, 239, 207, 153, 109, 211, 29, 209, 111, 224, 178, 249, 244, 35, 127, 72, 18, 36, + 211, 211, 251, 98, 142, 200, 131, 114, 93, 84, 51, 75, 87, 252, 180, 27, 111, 16, 238, 233, 186, 168, 181, + 99, 31, 43, 101, 0, 127, 202, 206, 219, 93, 42, 206, 142, 209, 28, 58, 252, 168, 119, 101, 236, 157, 53, + 61, 31, 129, 120, 81, 116, 129, 20, 20, 170, 181, 173, 33, 40, 67, 24, 78, 159, 158, 255, 210, 197, 193, + 73, 127, 254, 239, 216, 130, 121, 97, 249, 22, 136, 179, 221, 10, 40, 246, 89, 48, 36, 57, 186, 236, 58, + 110, 45, 217, 173, 46, 210, 225, 223, 235, 91, 55, 79, 82, 162, 17, 22, 223, 11, 1, 189, 11, 180, 47, + 191, 225, 173, 175, 46, 212, 158, 75, 68, 62, 99, 203, 115, 174, 139, 141, 81, 135, 75, 131, 221, 138, 247, + 155, 154, 56, 220, 123, 32, 28, 151, 47, 241, 14, 14, 84, 147, 39, 247, 99, 249, 187, 111, 26, 60, 140, + 194, 5, 130, 101, 114, 74, 229, 127, 52, 252, 233, 79, 13, 213, 46, 196, 171, 215, 250, 166, 165, 37, 171, + 143, 244, 108, 131, 87, 95, 248, 223, 169, 98, 195, 211, 227, 201, 122, 202, 170, 220, 98, 246, 166, 182, 115, + 177, 67, 189, 94, 175, 28, 9, 225, 100, 173, 217, 150, 32, 90, 72, 190, 222, 65, 169, 170, 175, 179, 194, + 42, 102, 122, 218, 174, 217, 177, 200, 249, 111, 195, 246, 172, 60, 60, 160, 113, 205, 27, 87, 233, 189, 95, + 10, 176, 72, 54, 206, 131, 114, 245, 31, 123, 31, 55, 15, 189, 165, 122, 93, 45, 65, 154, 252, 102, 37, + 39, 132, 85, 148, 140, 37, 110, 201, 35, 31, 54, 243, 174, 164, 230, 255, 55, 7, 220, 123, 246, 98, 31, + 12, 137, 63, 95, 44, 255, 165, 241, 78, 225, 172, 72, 119, 209, 63, 209, 116, 7, 152, 47, 38, 191, 210, + 155, 175, 100, 95, 147, 210, 30, 137, 13, 226, 201, 232, 63, 169, 236, 164, 29, 225, 254, 225, 181, 253, 155, + 211, 100, 33, 255, 75, 14, 86, 238, 0, 124, 199, 102, 67, 230, 84, 161, 90, 169, 181, 140, 11, 133, 41, + 201, 177, 136, 234, 155, 31, 67, 254, 100, 21, 224, 17, 164, 218, 193, 147, 191, 64, 21, 125, 43, 63, 221, + 90, 160, 95, 68, 244, 34, 115, 79, 241, 171, 151, 21, 167, 122, 221, 1, 236, 85, 41, 156, 73, 72, 218, + 146, 171, 44, 62, 233, 174, 75, 237, 232, 23, 207, 64, 128, 165, 209, 135, 170, 49, 143, 238, 109, 228, 200, + 67, 114, 135, 72, 51, 253, 107, 74, 202, 58, 160, 186, 194, 63, 97, 175, 241, 180, 58, 187, 19, 212, 175, + 225, 184, 224, 21, 190, 67, 154, 106, 170, 129, 195, 92, 4, 122, 15, 252, 88, 182, 218, 127, 20, 137, 226, + 143, 205, 70, 126, 137, 190, 120, 186, 175, 44, 169, 76, 26, 45, 213, 99, 171, 144, 90, 100, 175, 104, 66, + 159, 155, 218, 31, 54, 242, 98, 50, 204, 82, 65, 8, 138, 143, 109, 222, 40, 65, 244, 4, 165, 153, 195, + 253, 217, 11, 186, 54, 53, 37, 66, 35, 95, 241, 185, 18, 18, 37, 245, 63, 157, 139, 109, 250, 235, 110, + 6, 85, 234, 85, 1, 211, 82, 159, 245, 71, 92, 52, 233, 124, 250, 244, 174, 218, 47, 37, 232, 203, 199, + 243, 175, 244, 166, 44, 108, 93, 79, 51, 241, 241, 64, 124, 61, 69, 246, 241, 140, 255, 79, 235, 58, 196, + 255, 216, 207, 193, 143, 86, 89, 30, 158, 57, 120, 55, 60, 130, 220, 148, 89, 138, 253, 210, 86, 125, 104, + 83, 244, 158, 100, 56, 216, 103, 13, 77, 106, 41, 163, 164, 237, 2, 95, 141, 76, 250, 238, 199, 129, 41, + 69, 128, 127, 87, 22, 124, 27, 95, 172, 233, 47, 29, 198, 57, 234, 251, 207, 107, 215, 49, 147, 20, 116, + 146, 77, 48, 138, 156, 137, 218, 160, 168, 31, 175, 101, 61, 169, 170, 6, 113, 255, 228, 156, 107, 248, 243, + 109, 201, 245, 134, 240, 37, 96, 165, 63, 94, 56, 191, 97, 58, 57, 
125, 247, 21, 249, 124, 143, 140, 235, + 234, 45, 188, 27, 239, 149, 235, 122, 76, 82, 124, 242, 163, 65, 95, 1, 64, 227, 26, 55, 211, 151, 138, + 124, 191, 100, 88, 55, 80, 16, 31, 101, 51, 138, 141, 235, 99, 60, 118, 40, 168, 141, 138, 25, 97, 167, + 37, 174, 98, 213, 26, 76, 121, 235, 113, 247, 131, 114, 133, 252, 245, 244, 33, 108, 253, 122, 201, 143, 92, + 140, 29, 201, 197, 144, 65, 189, 251, 147, 211, 43, 234, 170, 44, 101, 36, 40, 29, 19, 216, 134, 82, 71, + 147, 110, 170, 208, 184, 3, 212, 1, 121, 63, 188, 135, 53, 31, 89, 232, 63, 158, 70, 30, 231, 102, 65, + 22, 150, 244, 121, 148, 90, 103, 155, 242, 48, 185, 245, 59, 105, 182, 110, 59, 127, 33, 248, 52, 192, 125, + 218, 50, 169, 252, 255, 122, 71, 158, 255, 158, 178, 65, 164, 223, 137, 182, 78, 253, 118, 66, 130, 238, 90, + 81, 214, 160, 24, 40, 113, 112, 63, 111, 218, 234, 203, 30, 49, 97, 137, 97, 30, 212, 158, 89, 156, 224, + 56, 70, 203, 48, 119, 148, 255, 18, 4, 35, 117, 42, 79, 248, 178, 226, 250, 67, 250, 71, 215, 170, 197, + 176, 124, 207, 83, 34, 72, 73, 46, 133, 1, 15, 117, 18, 225, 95, 60, 186, 169, 53, 176, 136, 63, 104, + 75, 219, 115, 78, 239, 211, 44, 198, 130, 156, 166, 98, 215, 144, 240, 93, 209, 254, 255, 251, 150, 124, 172, + 228, 6, 183, 79, 127, 241, 62, 121, 129, 183, 228, 30, 237, 244, 108, 246, 65, 75, 241, 145, 82, 217, 210, + 122, 79, 79, 244, 167, 167, 119, 20, 72, 202, 139, 57, 212, 141, 246, 239, 50, 204, 179, 122, 156, 146, 216, + 243, 73, 159, 31, 182, 227, 67, 191, 130, 248, 227, 191, 85, 85, 241, 37, 147, 37, 171, 50, 38, 240, 241, + 140, 61, 83, 236, 116, 62, 245, 44, 91, 227, 189, 220, 246, 243, 186, 175, 145, 88, 23, 61, 175, 6, 175, + 148, 255, 118, 53, 237, 38, 109, 110, 254, 191, 236, 58, 3, 73, 194, 159, 83, 80, 145, 251, 129, 101, 216, + 241, 156, 178, 255, 46, 230, 198, 160, 190, 236, 187, 91, 48, 127, 111, 250, 168, 218, 87, 111, 95, 101, 207, + 168, 98, 64, 199, 207, 170, 37, 101, 125, 190, 166, 139, 157, 39, 246, 39, 21, 137, 250, 172, 253, 80, 113, + 5, 18, 155, 92, 18, 180, 111, 239, 165, 246, 91, 87, 57, 223, 206, 62, 245, 227, 121, 84, 180, 162, 25, + 240, 223, 225, 111, 28, 203, 8, 20, 203, 17, 235, 56, 233, 198, 158, 250, 181, 57, 191, 173, 72, 136, 169, + 232, 179, 51, 222, 165, 152, 85, 125, 199, 118, 82, 182, 77, 240, 249, 255, 209, 127, 49, 190, 235, 18, 143, + 125, 52, 60, 12, 111, 95, 115, 117, 111, 220, 90, 159, 147, 26, 152, 84, 137, 70, 8, 182, 12, 249, 252, + 191, 247, 49, 180, 130, 239, 87, 201, 99, 86, 80, 122, 26, 104, 243, 250, 245, 225, 46, 242, 102, 143, 15, + 154, 27, 134, 236, 86, 50, 31, 152, 254, 47, 127, 249, 204, 245, 53, 207, 100, 84, 47, 88, 210, 213, 161, + 227, 71, 182, 30, 141, 127, 78, 83, 33, 213, 131, 33, 169, 222, 202, 80, 179, 164, 227, 222, 218, 65, 254, + 228, 75, 140, 218, 190, 214, 157, 224, 118, 74, 82, 134, 148, 70, 143, 225, 35, 43, 53, 91, 177, 240, 239, + 132, 139, 226, 245, 39, 97, 45, 74, 251, 87, 175, 185, 187, 205, 83, 33, 45, 155, 164, 132, 61, 146, 56, + 229, 175, 110, 94, 204, 177, 111, 126, 206, 105, 67, 215, 205, 253, 237, 46, 89, 116, 14, 35, 92, 123, 43, + 242, 249, 35, 227, 233, 16, 251, 175, 245, 103, 127, 196, 233, 43, 23, 135, 27, 189, 97, 29, 5, 178, 37, + 52, 27, 140, 229, 63, 86, 138, 32, 34, 87, 38, 206, 32, 155, 191, 48, 61, 63, 113, 27, 43, 92, 149, + 143, 47, 137, 116, 150, 127, 49, 82, 25, 22, 99, 7, 61, 255, 246, 203, 83, 132, 166, 44, 75, 211, 131, + 94, 26, 127, 105, 73, 189, 163, 182, 57, 95, 248, 34, 97, 244, 255, 47, 131, 55, 188, 97, 80, 205, 124, + 184, 111, 50, 168, 44, 159, 249, 74, 164, 200, 227, 33, 
253, 24, 217, 38, 135, 130, 87, 114, 90, 199, 216, + 27, 127, 252, 246, 134, 80, 49, 147, 186, 122, 107, 70, 249, 30, 95, 72, 34, 205, 229, 95, 79, 142, 71, + 194, 133, 36, 215, 188, 100, 111, 129, 190, 93, 113, 227, 167, 29, 254, 212, 38, 141, 182, 105, 164, 121, 147, + 225, 252, 137, 32, 50, 31, 64, 35, 168, 78, 185, 31, 38, 73, 255, 33, 60, 38, 125, 49, 12, 242, 248, + 94, 104, 114, 78, 255, 101, 76, 158, 95, 13, 212, 169, 242, 116, 250, 154, 184, 86, 68, 172, 166, 95, 182, + 21, 191, 24, 145, 207, 26, 119, 201, 58, 126, 203, 10, 147, 111, 165, 251, 99, 64, 26, 77, 218, 208, 244, + 239, 93, 226, 254, 68, 198, 82, 188, 47, 200, 15, 44, 100, 33, 90, 55, 71, 249, 148, 175, 219, 59, 159, + 191, 127, 8, 125, 206, 27, 67, 136, 234, 89, 12, 203, 163, 52, 78, 167, 144, 68, 112, 127, 176, 235, 113, + 116, 39, 217, 143, 26, 23, 232, 23, 54, 203, 231, 43, 41, 22, 56, 137, 150, 111, 47, 107, 242, 146, 190, + 63, 154, 248, 73, 25, 251, 161, 77, 187, 158, 34, 87, 21, 224, 133, 0, 239, 190, 152, 87, 128, 229, 24, + 130, 40, 95, 111, 241, 251, 251, 12, 26, 254, 81, 111, 47, 91, 137, 22, 61, 202, 85, 175, 174, 169, 196, + 218, 253, 1, 192, 54, 46, 81, 190, 33, 17, 248, 66, 248, 240, 226, 49, 52, 148, 177, 246, 20, 41, 211, + 128, 12, 247, 221, 130, 118, 113, 183, 47, 230, 233, 57, 162, 183, 74, 28, 195, 132, 210, 124, 248, 255, 166, + 51, 207, 135, 157, 115, 248, 238, 126, 195, 175, 177, 36, 68, 225, 216, 205, 214, 200, 243, 248, 22, 231, 138, + 168, 34, 210, 145, 236, 23, 175, 25, 248, 208, 77, 215, 160, 242, 209, 168, 25, 216, 191, 84, 182, 191, 143, + 243, 240, 182, 210, 241, 170, 216, 182, 29, 87, 168, 99, 128, 216, 169, 245, 68, 33, 153, 164, 28, 56, 79, + 157, 145, 17, 158, 124, 171, 255, 225, 225, 252, 43, 61, 249, 165, 40, 197, 111, 7, 8, 166, 211, 241, 84, + 185, 240, 35, 48, 146, 167, 163, 110, 247, 119, 50, 112, 101, 239, 87, 66, 34, 107, 106, 138, 29, 234, 43, + 151, 206, 254, 95, 142, 85, 46, 74, 46, 89, 200, 199, 210, 198, 156, 126, 247, 17, 44, 140, 84, 159, 187, + 213, 202, 151, 46, 234, 165, 78, 247, 210, 250, 228, 223, 85, 27, 137, 127, 131, 91, 42, 124, 97, 237, 139, + 24, 175, 47, 104, 81, 163, 210, 56, 193, 107, 115, 203, 250, 184, 140, 206, 51, 13, 118, 156, 164, 132, 226, + 127, 213, 32, 23, 66, 55, 137, 154, 29, 125, 226, 85, 38, 207, 180, 91, 24, 35, 167, 224, 138, 93, 66, + 40, 85, 185, 43, 88, 184, 142, 37, 3, 234, 220, 165, 127, 219, 130, 48, 49, 175, 55, 79, 196, 50, 197, + 141, 111, 217, 44, 209, 150, 57, 253, 208, 107, 222, 146, 130, 22, 161, 118, 96, 177, 219, 128, 67, 114, 228, + 143, 247, 49, 204, 69, 25, 7, 169, 52, 35, 168, 255, 181, 86, 79, 169, 135, 151, 21, 74, 119, 106, 15, + 26, 68, 16, 118, 42, 93, 207, 189, 59, 138, 190, 206, 204, 123, 53, 197, 83, 61, 149, 123, 54, 26, 247, + 22, 63, 254, 255, 24, 154, 63, 232, 29, 97, 233, 128, 207, 79, 169, 239, 251, 57, 38, 51, 102, 63, 72, + 226, 234, 250, 254, 190, 141, 96, 137, 54, 234, 111, 111, 190, 106, 95, 233, 147, 226, 15, 63, 251, 41, 53, + 58, 132, 97, 73, 38, 80, 213, 162, 250, 125, 179, 154, 27, 229, 171, 159, 237, 145, 128, 207, 211, 74, 134, + 249, 155, 25, 47, 163, 247, 183, 184, 94, 59, 212, 227, 188, 162, 253, 114, 79, 250, 92, 130, 111, 42, 75, + 45, 136, 209, 108, 30, 100, 93, 221, 142, 125, 228, 143, 255, 219, 238, 67, 182, 40, 158, 133, 21, 248, 143, + 195, 220, 193, 96, 111, 120, 208, 6, 249, 42, 225, 30, 159, 223, 116, 203, 35, 243, 9, 147, 251, 29, 217, + 184, 125, 131, 165, 79, 85, 53, 127, 47, 146, 229, 245, 185, 173, 37, 230, 209, 118, 108, 78, 114, 246, 208, + 200, 186, 239, 244, 2, 188, 114, 
255, 76, 89, 86, 157, 237, 225, 41, 43, 43, 86, 89, 52, 132, 26, 148, + 253, 31, 185, 174, 248, 131, 92, 34, 173, 11, 72, 198, 145, 215, 180, 106, 173, 44, 99, 46, 116, 31, 51, + 71, 77, 203, 138, 17, 55, 70, 183, 181, 155, 151, 229, 81, 204, 249, 19, 158, 237, 134, 96, 126, 88, 139, + 111, 138, 247, 110, 253, 217, 107, 111, 77, 234, 130, 199, 177, 143, 199, 47, 123, 155, 32, 99, 117, 16, 181, + 155, 87, 166, 244, 42, 36, 251, 108, 15, 45, 108, 33, 85, 97, 117, 203, 142, 45, 102, 225, 249, 255, 191, + 19, 17, 255, 229, 214, 84, 89, 201, 41, 249, 242, 208, 188, 220, 51, 140, 233, 224, 97, 192, 114, 242, 46, + 23, 241, 29, 59, 150, 179, 21, 239, 109, 77, 205, 63, 81, 102, 116, 33, 87, 84, 190, 255, 249, 71, 40, + 97, 117}; + const std::vector<uint8_t> fc2_experts_weights = { + 194, 252, 114, 86, 142, 245, 201, 173, 104, 137, 96, 176, 255, 143, 35, 68, 92, 181, 221, 16, 7, 116, 143, + 99, 41, 223, 180, 116, 244, 139, 190, 255, 138, 100, 204, 207, 5, 23, 77, 101, 217, 92, 134, 241, 138, 183, + 146, 41, 171, 194, 248, 212, 175, 20, 12, 202, 83, 43, 255, 239, 87, 77, 70, 168, 40, 227, 76, 38, 121, + 67, 146, 21, 139, 160, 109, 169, 212, 66, 131, 219, 235, 113, 1, 246, 88, 5, 180, 106, 247, 179, 253, 170, + 103, 204, 178, 62, 173, 37, 80, 214, 130, 113, 23, 137, 0, 183, 32, 115, 1, 253, 222, 223, 19, 114, 243, + 57, 208, 115, 251, 89, 66, 63, 81, 9, 83, 94, 128, 146, 23, 85, 208, 96, 142, 87, 190, 157, 83, 0, + 169, 183, 221, 110, 96, 118, 39, 215, 61, 44, 19, 42, 64, 207, 130, 184, 50, 255, 132, 126, 191, 108, 110, + 226, 232, 44, 175, 234, 228, 67, 85, 191, 56, 90, 44, 219, 49, 21, 246, 113, 70, 69, 67, 118, 132, 241, + 118, 75, 185, 159, 254, 44, 183, 104, 40, 86, 18, 95, 106, 13, 65, 216, 245, 242, 46, 230, 223, 59, 167, + 239, 230, 159, 252, 41, 141, 193, 65, 187, 74, 95, 187, 86, 99, 68, 255, 230, 93, 241, 138, 74, 56, 162, + 138, 150, 143, 220, 59, 152, 212, 218, 251, 165, 13, 207, 184, 78, 11, 49, 79, 245, 181, 136, 197, 239, 110, + 7, 251, 90, 78, 151, 23, 74, 27, 159, 97, 82, 79, 181, 228, 111, 136, 86, 227, 86, 170, 147, 172, 245, + 65, 136, 162, 119, 224, 94, 252, 1, 190, 25, 96, 181, 83, 95, 237, 132, 22, 19, 251, 63, 94, 99, 137, + 62, 219, 223, 159, 206, 246, 96, 71, 85, 255, 119, 55, 190, 175, 108, 175, 189, 101, 172, 224, 175, 191, 204, + 137, 52, 46, 107, 6, 177, 160, 96, 253, 131, 123, 215, 95, 107, 220, 123, 143, 9, 251, 101, 87, 241, 163, + 162, 66, 97, 166, 206, 251, 244, 206, 125, 242, 246, 9, 110, 49, 168, 82, 249, 238, 214, 209, 91, 212, 76, + 122, 12, 88, 247, 41, 241, 253, 29, 203, 226, 51, 60, 17, 0, 255, 193, 237, 213, 45, 96, 201, 175, 111, + 148, 209, 250, 140, 237, 197, 55, 248, 158, 236, 245, 78, 111, 27, 103, 235, 191, 23, 250, 216, 95, 2, 45, + 79, 46, 116, 153, 169, 43, 183, 41, 86, 119, 228, 158, 138, 62, 145, 30, 15, 199, 59, 223, 67, 124, 237, + 186, 98, 138, 94, 184, 156, 249, 117, 40, 223, 158, 38, 91, 176, 94, 160, 254, 23, 175, 142, 235, 3, 146, + 251, 255, 205, 42, 122, 255, 26, 149, 136, 207, 79, 31, 218, 150, 248, 161, 14, 220, 55, 76, 135, 3, 195, + 108, 35, 239, 171, 120, 201, 250, 219, 16, 72, 181, 232, 245, 63, 243, 197, 68, 177, 83, 53, 106, 254, 28, + 89, 45, 31, 243, 162, 150, 122, 82, 155, 74, 77, 220, 46, 49, 254, 105, 43, 18, 62, 234, 13, 189, 85, + 189, 222, 206, 63, 99, 171, 247, 242, 136, 166, 167, 31, 166, 174, 85, 232, 25, 7, 179, 91, 218, 212, 129, + 68, 31, 208, 101, 133, 45, 11, 163, 41, 104, 190, 169, 239, 254, 196, 162, 158, 220, 62, 239, 25, 225, 48, + 20, 147, 119, 151, 178, 216, 238, 185, 119, 191, 34, 191, 87, 230, 94, 212, 148, 130,
130, 15, 101, 26, 99, + 130, 17, 252, 229, 56, 247, 175, 12, 40, 12, 224, 210, 247, 23, 152, 63, 67, 131, 220, 86, 217, 69, 116, + 5, 240, 239, 173, 165, 138, 254, 182, 208, 192, 27, 39, 97, 171, 251, 22, 154, 110, 149, 89, 107, 245, 100, + 230, 64, 137, 85, 62, 152, 166, 223, 246, 250, 186, 183, 146, 212, 207, 169, 191, 247, 246, 19, 248, 49, 47, + 180, 213, 25, 108, 253, 153, 189, 149, 193, 191, 87, 44, 88, 63, 62, 254, 166, 55, 191, 207, 81, 143, 161, + 247, 15, 55, 174, 158, 214, 233, 181, 156, 134, 109, 159, 179, 245, 28, 197, 86, 249, 145, 113, 14, 44, 101, + 41, 221, 15, 230, 151, 232, 47, 69, 199, 172, 147, 143, 113, 70, 70, 222, 162, 71, 29, 38, 71, 183, 252, + 22, 233, 175, 247, 29, 95, 118, 153, 17, 83, 31, 125, 137, 238, 255, 47, 92, 104, 211, 106, 104, 102, 39, + 86, 228, 92, 105, 108, 84, 175, 159, 185, 190, 121, 144, 74, 222, 159, 38, 35, 192, 209, 205, 1, 211, 73, + 84, 162, 192, 236, 250, 155, 245, 223, 52, 242, 14, 242, 37, 245, 170, 110, 181, 82, 251, 244, 143, 71, 56, + 231, 86, 236, 55, 130, 21, 232, 156, 40, 96, 113, 17, 253, 35, 168, 107, 247, 254, 253, 147, 252, 85, 222, + 211, 238, 231, 70, 149, 174, 246, 194, 212, 163, 254, 59, 233, 45, 182, 244, 132, 173, 33, 64, 79, 71, 118, + 42, 134, 0, 202, 30, 47, 103, 212, 40, 232, 124, 116, 75, 29, 254, 253, 124, 47, 217, 35, 87, 114, 228, + 75, 47, 54, 212, 222, 225, 72, 99, 109, 198, 88, 63, 1, 98, 76, 242, 18, 193, 24, 47, 1, 103, 153, + 41, 11, 140, 100, 217, 110, 96, 24, 136, 97, 116, 138, 42, 48, 114, 232, 164, 201, 62, 236, 255, 202, 249, + 253, 66, 246, 68, 21, 220, 15, 47, 239, 205, 150, 127, 104, 154, 20, 176, 48, 41, 90, 149, 181, 251, 19, + 232, 41, 64, 55, 91, 155, 23, 33, 100, 102, 238, 79, 19, 169, 49, 152, 105, 98, 137, 250, 166, 38, 33, + 254, 251, 160, 142, 200, 171, 236, 247, 153, 177, 36, 178, 219, 228, 253, 145, 12, 152, 68, 34, 89, 163, 34, + 47, 247, 140, 228, 45, 179, 100, 3, 134, 4, 151, 67, 144, 180, 40, 245, 234, 113, 29, 149, 201, 44, 132, + 86, 90, 228, 121, 64, 66, 78, 194, 200, 235, 92, 133, 3, 157, 126, 152, 191, 85, 222, 85, 168, 83, 202, + 178, 216, 154, 101, 95, 186, 106, 183, 53, 103, 202, 110, 72, 37, 81, 15, 194, 228, 251, 255, 246, 219, 65, + 168, 83, 30, 11, 63, 23, 28, 87, 106, 185, 138, 246, 89, 132, 45, 255, 107, 85, 18, 90, 244, 158, 214, + 76, 81, 90, 249, 151, 58, 232, 91, 231, 99, 214, 141, 165, 69, 140, 234, 193, 72, 73, 158, 23, 38, 244, + 213, 247, 250, 8, 88, 241, 159, 91, 111, 253, 99, 66, 232, 155, 17, 167, 61, 199, 236, 255, 204, 253, 118, + 34, 194, 8, 136, 237, 159, 117, 63, 247, 30, 83, 229, 226, 10, 253, 235, 148, 85, 85, 127, 221, 153, 165, + 245, 141, 239, 227, 216, 24, 91, 50, 246, 171, 215, 113, 184, 17, 153, 164, 60, 29, 105, 190, 123, 143, 111, + 177, 225, 68, 78, 210, 217, 28, 191, 26, 255, 187, 255, 166, 69, 239, 79, 236, 85, 155, 34, 6, 247, 172, + 61, 41, 68, 159, 247, 54, 159, 103, 98, 112, 209, 188, 252, 148, 125, 113, 230, 167, 154, 53, 143, 168, 172, + 66, 253, 52, 209, 191, 253, 238, 9, 46, 240, 23, 175, 148, 85, 54, 116, 55, 214, 253, 152, 175, 6, 98, + 158, 109, 204, 219, 107, 107, 109, 206, 150, 57, 250, 125, 170, 229, 20, 175, 104, 90, 250, 54, 112, 46, 92, + 250, 156, 181, 177, 177, 166, 167, 250, 181, 209, 208, 127, 235, 151, 133, 86, 147, 42, 42, 84, 133, 136, 170, + 202, 172, 28, 115, 160, 106, 251, 79, 205, 177, 67, 169, 69, 184, 109, 207, 164, 38, 247, 16, 245, 251, 129, + 246, 85, 156, 255, 218, 74, 212, 246, 81, 196, 161, 196, 23, 174, 234, 115, 35, 56, 126, 133, 16, 65, 139, + 113, 63, 243, 23, 18, 211, 30, 236, 167, 189, 129, 111, 15, 
228, 254, 120, 127, 141, 215, 232, 119, 119, 113, + 127, 254, 111, 241, 86, 14, 154, 50, 201, 146, 87, 25, 239, 39, 24, 232, 94, 185, 226, 130, 95, 0, 127, + 6, 142, 209, 31, 203, 85, 215, 81, 22, 189, 109, 109, 252, 219, 175, 70, 246, 194, 198, 21, 227, 52, 247, + 149, 32, 242, 126, 5, 236, 140, 73, 117, 177, 169, 162, 68, 44, 132, 68, 77, 110, 54, 130, 161, 235, 118, + 154, 250, 47, 75, 249, 8, 169, 195, 231, 244, 192, 33, 155, 133, 27, 103, 173, 213, 231, 168, 44, 225, 30, + 9, 29, 19, 172, 211, 253, 226, 31, 191, 0, 237, 92, 149, 224, 121, 198, 142, 200, 255, 28, 33, 22, 196, + 66, 213, 73, 104, 64, 133, 221, 181, 147, 69, 105, 234, 243, 97, 169, 130, 242, 229, 171, 202, 210, 53, 29, + 97, 103, 20, 53, 224, 37, 74, 155, 98, 73, 255, 124, 188, 118, 166, 188, 103, 236, 19, 39, 106, 205, 254, + 161, 251, 181, 191, 223, 211, 79, 153, 209, 128, 160, 205, 36, 186, 39, 22, 187, 92, 237, 130, 212, 8, 190, + 91, 166, 145, 223, 232, 83, 88, 94, 242, 164, 94, 97, 120, 246, 31, 243, 57, 59, 100, 221, 60, 1, 39, + 178, 111, 61, 50, 115, 127, 72, 103, 226, 15, 221, 46, 237, 91, 111, 230, 43, 181, 183, 23, 146, 24, 67, + 248, 191, 168, 74, 170, 95, 163, 177, 171, 47, 80, 175, 11, 13, 111, 125, 254, 55, 202, 56, 216, 103, 216, + 134, 50, 222, 169, 211, 118, 107, 135, 221, 90, 102, 69, 142, 108, 40, 223, 86, 15, 157, 79, 110, 198, 12, + 70, 54, 184, 118, 159, 87, 150, 78, 25, 254, 251, 196, 205, 153, 157, 43, 17, 54, 214, 210, 174, 22, 238, + 161, 175, 239, 239, 28, 128, 234, 25, 67, 97, 136, 171, 177, 12, 95, 212, 195, 251, 50, 246, 175, 61, 3, + 89, 68, 142, 41, 43, 114, 67, 130, 48, 53, 106, 47, 119, 106, 253, 129, 242, 254, 223, 90, 77, 56, 221, + 31, 140, 194, 65, 80, 73, 228, 3, 94, 181, 72, 206, 223, 81, 171, 103, 203, 225, 152, 9, 204, 170, 247, + 26, 155, 197, 65, 239, 112, 5, 22, 72, 142, 69, 172, 127, 13, 131, 246, 189, 175, 242, 157, 130, 212, 59, + 175, 254, 55, 144, 163, 60, 59, 194, 103, 144, 95, 123, 18, 234, 183, 239, 45, 52, 212, 242, 172, 164, 128, + 50, 69, 183, 85, 18, 116, 120, 242, 151, 179, 4, 246, 161, 74, 69, 170, 39, 16, 84, 192, 27, 178, 122, + 30, 248, 79, 135, 98, 118, 26, 212, 52, 0, 175, 51, 29, 222, 134, 114, 206, 37, 167, 13, 110, 143, 136, + 164, 249, 59, 209, 230, 181, 205, 237, 203, 28, 177, 224, 228, 102, 39, 156, 22, 204, 142, 160, 195, 51, 207, + 56, 255, 63, 24, 123, 153, 92, 90, 244, 117, 187, 115, 21, 85, 129, 166, 122, 31, 19, 214, 173, 76, 255, + 204, 133, 127, 106, 94, 133, 225, 222, 97, 100, 170, 103, 11, 227, 2, 144, 174, 103, 177, 168, 133, 19, 83, + 255, 100, 55, 128, 230, 142, 163, 136, 206, 45, 122, 78, 99, 236, 150, 133, 34, 71, 93, 245, 246, 251, 149, + 65, 14, 240, 34, 234, 100, 250, 142, 80, 39, 97, 61, 76, 75, 85, 49, 113, 29, 99, 223, 156, 243, 90, + 240, 121, 208, 51, 204, 72, 17, 255, 105, 228, 128, 249, 147, 253, 127, 160, 38, 207, 22, 79, 137, 226, 207, + 218, 213, 166, 154, 177, 4, 204, 235, 182, 59, 245, 206, 213, 244, 135, 71, 25, 92, 216, 142, 81, 129, 255, + 100, 99, 36, 38, 79, 47, 168, 252, 7, 146, 46, 191, 72, 131, 83, 98, 121, 94, 178, 31, 41, 229, 220, + 210, 255, 209, 16, 40, 159, 126, 189, 250, 10, 89, 245, 39, 222, 208, 124, 201, 127, 84, 69, 125, 186, 65, + 250, 52, 79, 165, 45, 232, 234, 250, 162, 85, 220, 186, 165, 229, 136, 255, 72, 34, 47, 239, 15, 213, 129, + 25, 247, 146, 116, 222, 242, 57, 229, 28, 34, 180, 222, 205, 162, 167, 237, 191, 141, 142, 164, 253, 44, 254, + 44, 183, 245, 250, 227, 65, 192, 165, 238, 39, 33, 210, 236, 204, 139, 95, 88, 140, 115, 116, 46, 131, 119, + 2, 141, 243, 37, 247, 49, 36, 56, 63, 218, 
242, 28, 149, 100, 49, 87, 183, 210, 250, 168, 148, 244, 117, + 68, 159, 53, 247, 76, 57, 82, 245, 63, 105, 3, 236, 94, 205, 232, 36, 248, 156, 254, 185, 7, 139, 218, + 99, 207, 166, 54, 148, 52, 97, 102, 62, 230, 143, 115, 170, 72, 140, 38, 202, 147, 49, 242, 245, 165, 219, + 158, 240, 111, 177, 193, 81, 176, 181, 14, 18, 110, 70, 55, 99, 80, 249, 127, 161, 29, 51, 220, 143, 44, + 204, 79, 168, 245, 244, 58, 254, 108, 184, 34, 198, 152, 196, 218, 239, 245, 61, 138, 145, 221, 65, 4, 68, + 104, 6, 39, 40, 241, 99, 104, 203, 108, 255, 127, 123, 246, 118, 93, 230, 70, 242, 114, 117, 189, 127, 84, + 233, 2, 183, 88, 47, 243, 217, 98, 89, 225, 89, 31, 138, 2, 221, 5, 79, 247, 63, 120, 99, 71, 184, + 62, 131, 142, 217, 209, 215, 251, 125, 126, 54, 109, 89, 193, 234, 51, 215, 72, 85, 140, 245, 39, 243, 24, + 70, 4, 18, 120, 151, 166, 170, 154, 105, 160, 65, 185, 132, 227, 238, 52, 188, 219, 71, 233, 130, 91, 93, + 239, 218, 72, 195, 199, 157, 199, 211, 98, 202, 118, 163, 28, 178, 64, 181, 169, 2, 191, 74, 161, 254, 19, + 33, 154, 163, 191, 187, 158, 232, 245, 255, 250, 205, 62, 70, 111, 63, 158, 94, 136, 253, 46, 220, 41, 174, + 183, 211, 91, 15, 136, 2, 75, 124, 246, 236, 204, 196, 66, 238, 79, 122, 225, 219, 108, 228, 156, 172, 39, + 107, 163, 191, 196, 182, 29, 69, 147, 250, 36, 42, 188, 116, 196, 163, 146, 161, 255, 228, 4, 125, 253, 124, + 72, 248, 19, 108, 223, 250, 103, 242, 78, 245, 90, 53, 58, 16, 42, 31, 177, 173, 245, 126, 84, 52, 55, + 243, 148, 239, 57, 126, 220, 252, 172, 72, 175, 65, 119, 38, 188, 178, 7, 180, 76, 95, 31, 230, 253, 123, + 81, 187, 184, 39, 111, 218, 94, 251, 187, 175, 24, 158, 90, 128, 130, 145, 33, 86, 46, 143, 33, 175, 134, + 215, 99, 54, 214, 70, 28, 204, 205, 238, 252, 172, 166, 210, 42, 11, 213, 67, 143, 48, 145, 152, 19, 147, + 18, 31, 17, 174, 233, 186, 132, 215, 24, 205, 87, 250, 100, 128, 249, 20, 66, 3, 79, 50, 88, 42, 145, + 156, 45, 246, 169, 230, 163, 130, 13, 164, 127, 76, 150, 60, 101, 209, 112, 34, 247, 117, 197, 23, 114, 164, + 27, 142, 56, 227, 79, 49, 159, 21, 177, 124, 212, 89, 250, 113, 6, 16, 159, 33, 231, 238, 80, 245, 195, + 207, 248, 205, 253, 145, 154, 220, 96, 164, 140, 231, 89, 161, 6, 213, 163, 30, 123, 232, 102, 188, 217, 159, + 128, 16, 29, 27, 28, 199, 114, 212, 95, 111, 255, 109, 215, 5, 101, 191, 216, 134, 194, 108, 98, 167, 61, + 218, 49, 159, 10, 95, 159, 158, 166, 37, 129, 62, 151, 144, 55, 36, 54, 107, 156, 216, 219, 203, 138, 36, + 250, 220, 215, 237, 226, 88, 159, 136, 107, 117, 153, 239, 155, 18, 156, 69, 49, 2, 118, 142, 241, 41, 168, + 56, 155, 250, 105, 112, 217, 189, 214, 164, 162, 244, 52, 95, 36, 35, 191, 66, 51, 75, 192, 171, 103, 118, + 93, 237, 124, 70, 141, 10, 62, 84, 150, 120, 228, 85, 156, 140, 144, 68, 189, 239, 222, 203, 119, 197, 132, + 209, 162, 135, 179, 163, 212, 117, 79, 47, 104, 239, 19, 41, 166, 95, 129, 187, 67, 44, 159, 141, 167, 248, + 43, 179, 42, 15, 94, 249, 234, 104, 118, 225, 254, 236, 204, 92, 13, 26, 204, 141, 195, 227, 2, 104, 251, + 22, 146, 225, 179, 250, 49, 224, 65, 253, 5, 64, 245, 122, 155, 253, 156, 100, 114, 239, 111, 97, 199, 251, + 133, 253, 25, 27, 143, 251, 212, 169, 66, 111, 131, 58, 23, 117, 31, 123, 61, 148, 166, 127, 179, 37, 163, + 3, 82, 94, 113, 41, 80, 56, 207, 67, 173, 41, 28, 249, 157, 10, 42, 98, 136, 201, 213, 127, 102, 34, + 44, 228, 230, 63, 23, 158, 15, 75, 69, 241, 72, 131, 6, 199, 253, 165, 175, 217, 39, 101, 29, 84, 213, + 51, 132, 21, 195, 127, 246, 254, 191, 165, 96, 71, 190, 91, 124, 252, 76, 97, 188, 153, 47, 210, 193, 105, + 71, 206, 209, 146, 102, 254, 251, 
43, 152, 117, 188, 195, 196, 88, 246, 223, 102, 100, 159, 7, 252, 43, 240, + 167, 106, 228, 175, 127, 237, 159, 249, 211, 203, 193, 191, 135, 109, 12, 190, 241, 67, 93, 130, 72, 251, 84, + 77, 243, 169, 112, 243, 219, 129, 23, 91, 251, 43, 188, 141, 32, 170, 91, 173, 231, 60, 198, 42, 159, 117, + 218, 197, 63, 226, 96, 127, 201, 179, 206, 186, 144, 219, 84, 145, 216, 51, 47, 2, 113, 241, 68, 149, 98, + 200, 203, 232, 43, 239, 21, 129, 255, 213, 175, 119, 246, 229, 22, 159, 49, 54, 246, 93, 2, 130, 116, 49, + 212, 38, 51, 244, 99, 63, 24, 220, 4, 60, 33, 250, 190, 121, 197, 254, 50, 120, 181, 65, 124, 0, 30, + 249, 209, 209, 124, 136, 241, 134, 4, 116, 174, 87, 11, 113, 177, 239, 82, 66, 125, 215, 84, 105, 153, 227, + 166, 132, 214, 76, 209, 163, 69, 58, 111, 167, 253, 217, 27, 20, 189, 196, 49, 180, 181, 0, 78, 99, 103, + 175, 67, 39, 202, 161, 48, 125, 182, 236, 196, 98, 69, 93, 246, 174, 179, 195, 35, 233, 196, 105, 95, 127, + 25, 7, 153, 69, 42, 94, 19, 113, 40, 171, 113, 33, 191, 176, 215, 71, 230, 134, 187, 213, 233, 107, 225, + 254, 33, 249, 248, 165, 74, 227, 56, 153, 191, 195, 249, 57, 223, 184, 218, 58, 246, 83, 149, 108, 135, 21, + 187, 33, 12, 179, 72, 152, 59, 2, 92, 18, 244, 185, 76, 243, 166, 247, 159, 185, 169, 116, 174, 67, 98, + 241, 248, 137, 239, 17, 0, 224, 0, 191, 230, 117, 248, 58, 170, 27, 251, 202, 176, 148, 57, 2, 220, 130, + 86, 242, 217, 193, 120, 89, 173, 143, 55, 5, 94, 41, 63, 75, 33, 221, 126, 16, 95, 166, 209, 253, 151, + 131, 17, 245, 250, 24, 129, 69, 161, 105, 67, 100, 222, 34, 67, 226, 68, 20, 151, 12, 117, 213, 29, 193, + 241, 90, 238, 15, 162, 60, 146, 66, 155, 52, 196, 197, 65, 138, 25, 44, 216, 56, 241, 191, 200, 76, 86, + 47, 214, 250, 46, 177, 66, 186, 21, 98, 46, 176, 216, 221, 140, 172, 42, 124, 105, 253, 239, 125, 13, 29, + 147, 84, 165, 142, 87, 5, 35, 116, 212, 17, 54, 193, 35, 92, 16, 88, 172, 184, 172, 39, 245, 149, 60, + 115, 76, 95, 185, 254, 16, 177, 116, 160, 191, 43, 191, 84, 204, 56, 77, 109, 91, 224, 120, 244, 95, 165, + 156, 109, 255, 226, 223, 142, 170, 220, 55, 99, 112, 27, 108, 223, 185, 107, 155, 104, 191, 78, 70, 81, 19, + 40, 242, 10, 56, 244, 79, 245, 204, 23, 33, 161, 108, 63, 239, 46, 174, 214, 105, 235, 180, 148, 88, 84, + 30, 27, 164, 84, 58, 30, 73, 220, 135, 108, 3, 117, 204, 51, 24, 236, 192, 115, 254, 74, 123, 58, 28, + 200, 228, 192, 61, 229, 214, 34, 51, 251, 89, 95, 240, 67, 162, 83, 90, 159, 79, 26, 115, 109, 129, 142, + 218, 249, 19, 164, 245, 60, 244, 193, 2, 136, 211, 95, 73, 194, 17, 44, 204, 251, 223, 92, 146, 213, 187, + 84, 49, 120, 216, 49, 63, 102, 160, 85, 185, 148, 233, 207, 135, 254, 245, 236, 4, 21, 61, 177, 109, 24, + 151, 195, 36, 109, 159, 20, 86, 153, 79, 159, 237, 176, 176, 223, 127, 134, 143, 150, 65, 243, 222, 249, 52, + 254, 244, 23, 19, 149, 180, 211, 181, 147, 100, 106, 175, 230, 90, 147, 239, 232, 253, 178, 56, 147, 97, 210, + 177, 108, 185, 98, 207, 98, 102, 18, 90, 161, 240, 94, 85, 165, 127, 226, 74, 46, 249, 21, 19, 165, 20, + 51, 43, 10, 209, 210, 23, 111, 215, 31, 98, 239, 191, 175, 86, 217, 76, 113, 61, 26, 7, 251, 102, 7, + 76, 75, 107, 72, 233, 115, 123, 178, 36, 23, 29, 187, 131, 216, 36, 9, 251, 79, 47, 93, 105, 210, 236, + 102, 228, 110, 156, 202, 228, 73, 95, 119, 118, 126, 247, 248, 118, 250, 227, 29, 113, 89, 66, 115, 154, 217, + 12, 104, 183, 138, 158, 252, 87, 156, 43, 159, 59, 19, 205, 128, 91, 41, 113, 100, 167, 253, 95, 44, 235, + 60, 51, 223, 27, 57, 40, 196, 182, 60, 16, 167, 3, 136, 249, 80, 187, 250, 94, 39, 231, 15, 181, 181, + 181, 97, 214, 211, 145, 119, 234, 78, 
24, 75, 190, 30, 218, 221, 215, 211, 93, 54, 173, 14, 28, 159, 69, + 73, 108, 46, 255, 186, 178, 20, 246, 189, 23, 45, 65, 132, 145, 98, 54, 220, 199, 136, 181, 64, 205, 240, + 72, 105, 157, 91, 89, 120, 93, 231, 20, 39, 216, 92, 109, 95, 177, 134, 57, 15, 126, 12, 218, 17, 58, + 210, 252, 71, 28, 28, 128, 147, 70, 217, 27, 69, 76, 185, 250, 25, 18, 46, 109, 127, 161, 211, 68, 202, + 251, 5, 153, 176, 102, 129, 20, 183, 249, 50, 61, 241, 252, 139, 220, 52, 235, 254, 188, 186, 230, 227, 214, + 225, 116, 64, 241, 251, 102, 213, 191, 115, 183, 46, 59, 79, 168, 93, 49, 53, 207, 242, 249, 40, 214, 197, + 193, 49, 210, 201, 63, 182, 136, 247, 29, 72, 113, 255, 164, 204, 109, 141, 120, 3, 73, 129, 230, 214, 107, + 190, 246, 94, 13, 157, 202, 235, 83, 239, 106, 101, 143, 82, 207, 107, 28, 66, 153, 191, 236, 11, 133, 242, + 125, 163, 39, 49, 69, 49, 222, 129, 241, 215, 170, 85, 177, 93, 103, 215, 221, 88, 31, 191, 61, 221, 105, + 117, 55, 26, 41, 27, 197, 220, 160, 188, 36, 162, 144, 43, 239, 28, 121, 250, 43, 148, 78, 244, 136, 110, + 26, 73, 81, 243, 183, 25, 19, 101, 94, 8, 211, 45, 225, 81, 247, 255, 236, 111, 19, 230, 60, 132, 4, + 182, 154, 154, 96, 185, 131, 51, 191, 58, 78, 135, 170, 201, 210, 81, 135, 94, 250, 49, 47, 180, 5, 37, + 145, 237, 223, 111, 40, 24, 255, 39, 101, 76, 94, 125, 106, 113, 33, 20, 254, 69, 180, 67, 137, 223, 61, + 128, 78, 34, 48, 75, 108, 31, 207, 180, 254, 205, 40, 182, 205, 233, 35, 149, 76, 37, 179, 134, 43, 184, + 219, 25, 203, 175, 188, 83, 202, 58, 90, 254, 39, 139, 203, 118, 159, 143, 140, 77, 159, 107, 104, 168, 203, + 44, 18, 191, 164, 39, 62, 150, 71, 52, 173, 13, 212, 244, 190, 57, 68, 119, 235, 226, 12, 87, 102, 41, + 156, 68, 47, 36, 39, 188, 231, 97, 65, 72, 129, 223, 19, 145, 227, 191, 14, 254, 213, 118, 119, 101, 214, + 182, 121, 2, 170, 21, 9, 14, 28, 255, 9, 205, 182, 69, 58, 172, 245, 120, 237, 11, 53, 49, 90, 79, + 235, 241, 95, 166, 90, 215, 119, 77, 111, 145, 13, 157, 33, 57, 39, 252, 127, 2, 153, 151, 12, 110, 81, + 136, 107, 239, 74, 38, 156, 179, 53, 89, 106, 169, 158, 72, 190, 38, 125, 187, 253, 113, 212, 139, 191, 147, + 155, 65, 159, 249, 251, 63, 67, 201, 173, 24, 76, 207, 101, 109, 31, 163, 136, 252, 202, 37, 52, 126, 120, + 81, 225}; + const std::vector<uint8_t> fc3_experts_weights = { + 123, 186, 42, 165, 140, 44, 223, 124, 136, 165, 213, 231, 84, 236, 233, 49, 19, 130, 60, 166, 63, 110, 31, + 193, 171, 238, 175, 234, 180, 113, 216, 26, 159, 185, 138, 72, 239, 132, 203, 94, 120, 242, 252, 33, 253, 154, + 247, 128, 96, 218, 30, 47, 218, 24, 82, 105, 64, 231, 188, 35, 255, 89, 198, 116, 84, 157, 44, 160, 175, + 53, 93, 90, 166, 198, 49, 152, 113, 238, 86, 121, 159, 234, 205, 166, 70, 130, 244, 136, 58, 124, 65, 174, + 228, 142, 107, 216, 135, 94, 196, 77, 42, 195, 99, 2, 61, 31, 31, 170, 41, 108, 54, 169, 204, 106, 67, + 204, 147, 111, 35, 27, 113, 55, 167, 76, 95, 191, 248, 208, 231, 172, 193, 31, 121, 140, 250, 187, 117, 93, + 27, 82, 229, 56, 251, 164, 33, 238, 22, 124, 152, 80, 154, 78, 241, 217, 250, 54, 229, 38, 183, 63, 74, + 161, 62, 69, 90, 232, 59, 49, 162, 140, 130, 210, 131, 253, 179, 159, 87, 120, 231, 123, 96, 166, 105, 95, + 18, 220, 249, 7, 150, 79, 29, 206, 112, 20, 181, 3, 52, 171, 124, 187, 88, 136, 98, 25, 94, 47, 212, + 53, 29, 239, 154, 164, 212, 61, 208, 171, 253, 243, 216, 71, 207, 231, 49, 174, 29, 244, 43, 168, 77, 203, + 129, 233, 123, 131, 33, 254, 151, 138, 110, 245, 89, 173, 220, 53, 60, 80, 147, 195, 255, 42, 8, 18, 139, + 97, 146, 248, 54, 220, 163, 22, 242, 58, 198, 57, 126, 122, 203, 118, 149, 13, 145, 193, 154,
237, 207, 99, + 82, 122, 2, 207, 26, 88, 62, 71, 173, 148, 251, 231, 41, 35, 168, 110, 226, 13, 109, 159, 216, 39, 100, + 43, 65, 167, 20, 255, 15, 37, 110, 220, 54, 175, 40, 88, 49, 5, 189, 146, 218, 110, 207, 118, 213, 76, + 191, 145, 12, 214, 238, 32, 123, 62, 58, 106, 223, 63, 24, 122, 127, 192, 189, 215, 129, 2, 212, 242, 207, + 161, 183, 246, 181, 194, 106, 153, 134, 105, 255, 242, 131, 252, 110, 153, 79, 240, 153, 191, 197, 205, 161, 136, + 55, 243, 141, 212, 116, 169, 162, 14, 45, 222, 22, 54, 142, 136, 247, 194, 72, 198, 20, 219, 125, 238, 45, + 255, 217, 134, 19, 234, 36, 57, 49, 108, 38, 54, 166, 101, 200, 31, 64, 90, 147, 144, 71, 41, 138, 12, + 78, 52, 66, 241, 216, 30, 187, 179, 205, 162, 63, 60, 236, 223, 20, 247, 199, 235, 194, 94, 223, 239, 127, + 169, 251, 238, 199, 172, 206, 164, 227, 220, 243, 105, 125, 113, 116, 245, 247, 194, 249, 28, 174, 151, 28, 114, + 47, 205, 148, 241, 95, 55, 55, 197, 188, 152, 96, 85, 190, 114, 108, 144, 226, 237, 86, 25, 130, 200, 17, + 199, 242, 100, 248, 154, 140, 44, 154, 209, 155, 27, 249, 26, 226, 214, 211, 172, 184, 122, 61, 17, 247, 255, + 14, 111, 39, 28, 63, 224, 27, 132, 42, 79, 45, 171, 155, 35, 230, 163, 215, 225, 175, 198, 184, 110, 21, + 247, 251, 83, 77, 223, 230, 135, 216, 199, 119, 131, 71, 121, 129, 185, 244, 155, 125, 205, 17, 249, 114, 133, + 79, 117, 53, 115, 65, 178, 226, 113, 69, 96, 246, 8, 16, 223, 210, 108, 68, 61, 89, 170, 3, 133, 11, + 197, 243, 89, 15, 201, 85, 125, 219, 193, 47, 106, 248, 52, 145, 191, 176, 207, 159, 219, 43, 242, 52, 250, + 250, 31, 248, 131, 175, 249, 43, 250, 239, 101, 45, 114, 62, 95, 167, 237, 190, 169, 109, 76, 119, 193, 229, + 157, 148, 90, 62, 23, 207, 40, 136, 131, 133, 119, 84, 101, 20, 217, 28, 144, 237, 169, 75, 100, 104, 110, + 113, 1, 101, 153, 253, 243, 13, 43, 171, 90, 255, 170, 204, 118, 251, 129, 253, 239, 233, 122, 134, 239, 225, + 167, 176, 94, 50, 76, 164, 234, 125, 109, 238, 114, 164, 164, 49, 163, 125, 193, 38, 32, 193, 242, 119, 38, + 116, 84, 189, 242, 161, 159, 179, 191, 253, 14, 43, 179, 37, 215, 181, 230, 94, 44, 140, 61, 82, 182, 210, + 221, 248, 104, 224, 249, 18, 126, 36, 40, 119, 72, 49, 171, 203, 86, 208, 15, 15, 201, 18, 164, 157, 171, + 155, 225, 176, 137, 228, 246, 228, 160, 159, 116, 245, 94, 195, 219, 197, 207, 254, 188, 135, 104, 64, 223, 233, + 212, 137, 31, 115, 115, 202, 40, 247, 111, 39, 142, 213, 171, 229, 10, 92, 234, 56, 50, 192, 69, 9, 253, + 208, 111, 36, 110, 137, 128, 94, 136, 191, 243, 41, 197, 61, 102, 46, 220, 39, 45, 155, 112, 110, 47, 84, + 21, 43, 68, 172, 154, 1, 76, 26, 110, 94, 181, 58, 63, 173, 127, 12, 19, 124, 6, 212, 182, 76, 21, + 209, 235, 123, 48, 151, 172, 189, 241, 240, 147, 144, 147, 216, 236, 53, 56, 252, 241, 123, 63, 25, 202, 131, + 40, 163, 119, 254, 98, 86, 245, 86, 229, 62, 42, 175, 76, 226, 8, 246, 251, 17, 76, 204, 59, 132, 73, + 196, 194, 88, 51, 117, 19, 146, 154, 90, 79, 72, 207, 177, 182, 99, 185, 113, 179, 91, 130, 156, 47, 209, + 187, 191, 130, 62, 22, 66, 189, 174, 137, 176, 174, 253, 161, 142, 35, 67, 243, 1, 96, 158, 67, 36, 255, + 223, 214, 181, 200, 234, 230, 127, 52, 155, 241, 6, 100, 56, 227, 28, 172, 149, 47, 63, 107, 40, 35, 173, + 174, 76, 173, 64, 78, 101, 103, 104, 150, 95, 240, 148, 226, 120, 56, 72, 118, 183, 142, 40, 246, 75, 34, + 37, 95, 230, 179, 96, 99, 245, 140, 64, 87, 103, 161, 158, 130, 87, 164, 11, 249, 254, 65, 14, 82, 60, + 179, 94, 242, 43, 238, 167, 108, 171, 21, 128, 238, 98, 136, 133, 196, 209, 111, 34, 202, 98, 90, 104, 9, + 222, 99, 70, 246, 47, 77, 219, 35, 28, 125, 79, 231, 54, 31, 
207, 119, 159, 95, 239, 54, 3, 39, 18, + 5, 96, 101, 166, 21, 15, 138, 198, 44, 154, 148, 50, 219, 73, 218, 240, 78, 76, 176, 252, 37, 221, 230, + 17, 67, 113, 11, 155, 201, 7, 53, 250, 188, 223, 102, 79, 217, 24, 170, 148, 25, 82, 11, 45, 75, 211, + 72, 10, 166, 78, 79, 17, 152, 143, 58, 118, 135, 90, 156, 24, 157, 253, 205, 60, 241, 8, 251, 77, 172, + 173, 148, 172, 66, 19, 59, 126, 50, 87, 172, 251, 25, 2, 212, 111, 215, 239, 28, 63, 79, 179, 167, 20, + 18, 222, 66, 167, 253, 226, 115, 132, 81, 156, 61, 225, 188, 178, 225, 185, 9, 27, 242, 0, 196, 191, 122, + 133, 8, 101, 21, 250, 28, 133, 39, 69, 126, 149, 43, 239, 152, 127, 155, 2, 231, 72, 156, 169, 20, 171, + 210, 230, 14, 116, 110, 128, 245, 224, 89, 153, 49, 63, 90, 52, 75, 106, 77, 188, 26, 186, 120, 146, 223, + 223, 158, 252, 93, 231, 155, 50, 99, 131, 42, 9, 88, 131, 247, 41, 245, 255, 77, 26, 151, 121, 79, 47, + 250, 41, 29, 120, 133, 198, 177, 202, 100, 80, 185, 112, 134, 179, 39, 164, 190, 68, 72, 48, 104, 253, 117, + 223, 190, 244, 134, 232, 153, 147, 56, 242, 197, 75, 77, 69, 148, 161, 155, 101, 173, 224, 66, 154, 129, 126, + 4, 188, 79, 90, 119, 146, 255, 140, 202, 133, 153, 190, 5, 251, 228, 174, 183, 164, 171, 251, 67, 209, 255, + 20, 67, 86, 226, 209, 178, 7, 87, 222, 69, 203, 57, 225, 24, 172, 250, 127, 241, 114, 215, 38, 28, 209, + 130, 89, 108, 28, 28, 138, 196, 195, 127, 56, 164, 178, 206, 236, 146, 29, 190, 129, 213, 193, 222, 84, 41, + 148, 249, 112, 172, 201, 0, 119, 252, 182, 238, 23, 210, 63, 94, 217, 146, 222, 238, 33, 51, 64, 80, 138, + 218, 136, 244, 105, 22, 126, 205, 221, 143, 237, 111, 152, 218, 223, 126, 229, 178, 201, 202, 84, 244, 33, 234, + 89, 196, 147, 77, 51, 156, 28, 50, 10, 154, 6, 245, 214, 69, 131, 141, 100, 86, 179, 157, 124, 104, 179, + 48, 174, 183, 189, 21, 98, 55, 87, 117, 119, 200, 16, 233, 204, 151, 169, 119, 236, 151, 8, 202, 202, 255, + 99, 41, 33, 207, 124, 239, 212, 147, 235, 18, 129, 37, 125, 151, 58, 79, 26, 75, 13, 169, 205, 167, 161, + 235, 76, 41, 29, 203, 4, 26, 25, 232, 17, 41, 68, 111, 147, 17, 31, 178, 111, 220, 148, 42, 136, 79, + 141, 186, 138, 191, 35, 61, 248, 121, 130, 218, 139, 234, 154, 56, 10, 188, 194, 33, 138, 246, 59, 13, 179, + 27, 78, 23, 174, 61, 78, 46, 74, 132, 100, 127, 78, 207, 97, 167, 82, 249, 208, 84, 65, 87, 74, 9, + 210, 195, 244, 249, 208, 223, 89, 246, 229, 62, 35, 74, 253, 82, 115, 50, 76, 111, 139, 21, 249, 165, 205, + 48, 54, 31, 159, 53, 74, 37, 173, 232, 62, 184, 63, 109, 49, 221, 193, 196, 139, 214, 212, 139, 245, 23, + 211, 243, 17, 135, 235, 31, 179, 184, 152, 210, 202, 245, 85, 31, 228, 193, 202, 234, 14, 133, 94, 118, 255, + 68, 210, 46, 204, 199, 106, 16, 25, 19, 126, 112, 104, 220, 241, 218, 67, 216, 219, 40, 192, 159, 205, 49, + 108, 58, 60, 25, 63, 184, 247, 213, 211, 54, 198, 46, 103, 186, 155, 224, 159, 246, 131, 190, 215, 221, 194, + 234, 26, 201, 71, 170, 40, 185, 236, 169, 210, 66, 249, 78, 245, 67, 35, 221, 252, 180, 25, 43, 200, 53, + 250, 108, 207, 118, 135, 99, 116, 213, 153, 255, 22, 94, 248, 58, 204, 4, 2, 190, 208, 191, 130, 87, 156, + 2, 174, 15, 248, 164, 159, 41, 39, 29, 47, 42, 102, 248, 116, 59, 77, 228, 157, 61, 121, 4, 163, 165, + 33, 156, 242, 247, 45, 31, 51, 170, 23, 183, 252, 245, 124, 4, 87, 103, 144, 118, 182, 237, 159, 140, 242, + 190, 131, 126, 16, 179, 105, 31, 72, 254, 243, 25, 207, 45, 194, 234, 241, 55, 18, 69, 118, 30, 1, 252, + 40, 164, 231, 225, 6, 23, 104, 157, 51, 249, 247, 4, 208, 149, 17, 58, 180, 248, 215, 140, 236, 178, 21, + 133, 110, 155, 79, 245, 69, 35, 255, 189, 245, 87, 216, 123, 3, 155, 
202, 253, 32, 237, 154, 120, 20, 232, + 47, 178, 109, 200, 177, 43, 8, 97, 82, 115, 166, 106, 161, 120, 28, 44, 227, 84, 165, 86, 229, 168, 9, + 234, 233, 80, 215, 118, 220, 176, 138, 218, 127, 251, 169, 236, 121, 215, 98, 72, 27, 221, 203, 67, 59, 194, + 79, 167, 118, 50, 98, 141, 162, 224, 181, 124, 57, 57, 191, 230, 201, 213, 15, 84, 6, 28, 112, 228, 53, + 196, 61, 143, 154, 249, 110, 47, 236, 33, 191, 95, 102, 22, 189, 73, 108, 112, 122, 23, 147, 216, 229, 147, + 255, 63, 41, 93, 129, 131, 251, 88, 168, 75, 39, 252, 249, 227, 52, 175, 93, 254, 96, 196, 121, 155, 36, + 95, 252, 88, 9, 74, 50, 254, 40, 64, 75, 121, 88, 185, 98, 15, 51, 87, 163, 253, 122, 132, 35, 196, + 194, 250, 163, 100, 70, 55, 39, 184, 4, 171, 216, 95, 204, 243, 111, 47, 254, 95, 47, 90, 182, 101, 46, + 140, 56, 97, 201, 83, 225, 128, 147, 66, 29, 222, 54, 133, 44, 97, 249, 177, 222, 158, 76, 59, 164, 195, + 230, 151, 58, 68, 117, 55, 8, 94, 107, 233, 48, 49, 214, 230, 114, 239, 48, 97, 92, 5, 101, 95, 58, + 245, 213, 148, 33, 38, 49, 232, 92, 24, 110, 188, 149, 243, 8, 152, 252, 202, 254, 203, 220, 72, 198, 176, + 218, 156, 63, 227, 106, 145, 178, 255, 86, 33, 191, 117, 175, 161, 249, 78, 146, 56, 136, 217, 158, 87, 98, + 212, 131, 81, 58, 6, 70, 59, 239, 155, 247, 169, 63, 92, 5, 228, 162, 69, 40, 221, 95, 111, 13, 182, + 79, 180, 10, 165, 242, 161, 149, 88, 246, 201, 96, 107, 89, 250, 220, 212, 253, 10, 108, 31, 167, 130, 126, + 119, 96, 225, 70, 149, 46, 151, 131, 84, 246, 188, 184, 146, 36, 160, 72, 194, 214, 161, 223, 235, 222, 233, + 243, 70, 158, 131, 103, 22, 120, 58, 89, 190, 17, 187, 92, 104, 191, 103, 187, 218, 244, 111, 246, 178, 73, + 95, 188, 254, 52, 116, 31, 195, 66, 148, 54, 231, 109, 151, 251, 35, 100, 146, 49, 96, 194, 213, 250, 143, + 95, 111, 193, 114, 212, 250, 225, 46, 249, 179, 211, 75, 149, 221, 133, 74, 138, 230, 61, 87, 215, 106, 199, + 246, 239, 31, 63, 81, 172, 247, 206, 87, 118, 1, 38, 125, 196, 78, 138, 99, 58, 81, 157, 82, 252, 59, + 118, 117, 83, 172, 39, 222, 163, 181, 121, 204, 142, 99, 101, 248, 55, 17, 182, 75, 71, 170, 77, 70, 154, + 242, 159, 178, 243, 201, 235, 165, 129, 127, 149, 158, 20, 52, 193, 8, 231, 210, 53, 47, 47, 220, 127, 101, + 243, 220, 219, 188, 221, 166, 84, 173, 140, 111, 106, 42, 88, 15, 200, 59, 248, 214, 246, 202, 242, 77, 238, + 210, 113, 10, 217, 241, 191, 201, 132, 122, 248, 173, 236, 214, 18, 201, 205, 198, 218, 36, 127, 95, 56, 251, + 233, 169, 218, 173, 243, 251, 84, 38, 133, 178, 108, 150, 181, 244, 78, 143, 34, 167, 87, 12, 255, 138, 242, + 194, 174, 198, 243, 100, 250, 227, 136, 35, 222, 53, 144, 249, 60, 105, 28, 111, 0, 53, 46, 239, 81, 21, + 137, 143, 84, 4, 91, 52, 233, 158, 55, 181, 38, 125, 25, 20, 198, 103, 17, 190, 248, 20, 186, 22, 86, + 165, 42, 175, 85, 85, 75, 171, 41, 51, 14, 207, 23, 250, 196, 249, 230, 50, 118, 29, 255, 191, 3, 55, + 58, 237, 106, 172, 15, 75, 117, 89, 122, 108, 248, 227, 79, 106, 245, 61, 211, 190, 49, 165, 202, 29, 82, + 94, 6, 141, 157, 173, 71, 69, 64, 69, 57, 181, 106, 234, 9, 69, 15, 203, 135, 60, 143, 197, 117, 117, + 175, 202, 251, 115, 238, 130, 72, 111, 93, 161, 138, 251, 23, 78, 205, 1, 211, 30, 22, 239, 25, 162, 140, + 10, 245, 22, 172, 181, 244, 50, 3, 85, 101, 120, 101, 68, 67, 70, 180, 195, 73, 121, 91, 138, 136, 241, + 248, 246, 98, 252, 169, 57, 159, 255, 223, 50, 78, 76, 205, 164, 136, 180, 44, 135, 175, 81, 215, 120, 154, + 38, 200, 129, 173, 201, 8, 129, 153, 247, 223, 172, 143, 80, 207, 197, 47, 84, 140, 253, 248, 103, 95, 156, + 230, 161, 60, 71, 95, 207, 101, 134, 247, 244, 196, 212, 140, 30, 
143, 60, 115, 145, 181, 89, 105, 61, 146, + 175, 136, 244, 84, 223, 22, 180, 140, 25, 143, 76, 36, 58, 127, 46, 106, 169, 94, 45, 240, 44, 26, 83, + 48, 253, 35, 237, 214, 251, 249, 233, 172, 102, 175, 93, 182, 170, 227, 221, 49, 23, 202, 231, 44, 247, 4, + 188, 212, 182, 132, 73, 159, 179, 216, 180, 127, 45, 145, 104, 99, 63, 228, 185, 239, 176, 24, 69, 78, 174, + 213, 230, 142, 172, 217, 197, 219, 103, 133, 76, 50, 192, 28, 30, 218, 40, 15, 141, 191, 173, 10, 93, 196, + 231, 205, 17, 60, 117, 219, 13, 255, 194, 190, 159, 77, 24, 116, 41, 206, 245, 224, 91, 255, 162, 171, 89, + 106, 168, 18, 18, 140, 89, 190, 38, 120, 107, 200, 236, 225, 170, 35, 81, 50, 243, 111, 124, 255, 83, 95, + 175, 242, 219, 166, 102, 254, 175, 176, 194, 142, 110, 15, 63, 59, 37, 241, 172, 113, 29, 120, 63, 68, 128, + 186, 187, 186, 197, 177, 214, 80, 204, 68, 250, 169, 188, 162, 47, 244, 27, 229, 247, 108, 96, 124, 100, 70, + 103, 205, 126, 239, 229, 155, 255, 221, 54, 15, 168, 96, 137, 149, 15, 209, 250, 54, 248, 246, 198, 105, 118, + 244, 46, 90, 111, 159, 15, 153, 211, 206, 199, 251, 96, 240, 55, 75, 191, 228, 152, 228, 178, 162, 143, 15, + 247, 81, 144, 18, 178, 145, 107, 182, 216, 90, 255, 149, 126, 246, 178, 212, 1, 71, 105, 114, 123, 223, 69, + 171, 91, 140, 137, 133, 143, 44, 107, 123, 80, 132, 240, 75, 175, 43, 102, 89, 98, 174, 187, 153, 250, 191, + 99, 182, 248, 69, 196, 142, 123, 25, 248, 37, 221, 127, 68, 110, 101, 134, 196, 64, 35, 223, 173, 217, 79, + 234, 96, 255, 134, 65, 253, 133, 50, 164, 119, 99, 109, 175, 251, 158, 242, 247, 223, 68, 177, 239, 57, 199, + 236, 54, 140, 254, 113, 101, 195, 41, 60, 38, 187, 210, 252, 83, 178, 85, 143, 135, 12, 148, 90, 115, 73, + 207, 165, 247, 178, 53, 159, 176, 219, 145, 177, 21, 112, 107, 170, 150, 91, 196, 27, 128, 210, 219, 65, 206, + 182, 214, 188, 253, 126, 94, 245, 158, 223, 151, 81, 232, 194, 157, 166, 147, 253, 161, 69, 77, 254, 4, 59, + 210, 70, 31, 118, 108, 68, 253, 200, 212, 52, 159, 68, 118, 190, 191, 157, 218, 110, 103, 247, 207, 183, 23, + 223, 223, 206, 105, 105, 95, 22, 83, 90, 117, 242, 123, 120, 91, 253, 63, 74, 109, 125, 207, 127, 154, 240, + 242, 17, 167, 246, 107, 59, 226, 135, 72, 196, 28, 197, 153, 150, 221, 53, 251, 60, 193, 20, 176, 81, 177, + 178, 116, 175, 106, 24, 95, 255, 5, 253, 95, 187, 51, 220, 252, 94, 45, 16, 188, 108, 132, 231, 136, 10, + 162, 230, 36, 138, 212, 111, 217, 53, 172, 97, 241, 123, 113, 181, 128, 159, 189, 71, 32, 46, 32, 246, 31, + 245, 138, 84, 120, 198, 62, 86, 218, 244, 121, 111, 179, 46, 244, 1, 120, 214, 209, 138, 249, 146, 40, 31, + 193, 19, 239, 37, 150, 12, 111, 82, 197, 55, 234, 31, 121, 213, 148, 90, 114, 253, 33, 242, 235, 138, 143, + 87, 230, 42, 173, 239, 104, 161, 115, 237, 212, 223, 67, 79, 250, 219, 249, 116, 65, 25, 105, 52, 116, 51, + 229, 214, 197, 34, 80, 117, 131, 219, 129, 251, 169, 247, 43, 46, 255, 43, 171, 41, 229, 43, 175, 205, 23, + 219, 169, 79, 166, 58, 81, 147, 240, 195, 138, 32, 138, 71, 59, 188, 232, 228, 253, 252, 182, 39, 183, 62, + 243, 24, 107, 212, 142, 152, 168, 63, 227, 81, 254, 105, 239, 203, 232, 74, 154, 73, 101, 106, 235, 247, 42, + 211, 17, 65, 231, 97, 128, 71, 218, 148, 196, 223, 155, 242, 122, 131, 41, 59, 136, 130, 42, 104, 250, 175, + 238, 42, 149, 245, 143, 252, 65, 243, 60, 170, 132, 221, 126, 251, 135, 102, 120, 113, 241, 249, 220, 37, 248, + 166, 19, 77, 30, 55, 242, 61, 191, 217, 211, 43, 15, 218, 194, 151, 255, 153, 33, 203, 97, 77, 9, 179, + 55, 140, 185, 90, 62, 163, 86, 171, 173, 118, 173, 247, 151, 89, 48, 86, 195, 131, 42, 8, 47, 254, 97, + 38, 182, 80, 
53, 111, 93, 147, 111, 240, 124, 102, 94, 127, 141, 247, 86, 11, 37, 11, 73, 218, 224, 62, + 75, 235, 197, 123, 98, 62, 127, 246, 27, 1, 132, 244, 217, 26, 182, 110, 220, 209, 153, 188, 207, 118, 72, + 109, 16, 23, 45, 244, 86, 240, 47, 49, 169, 233, 49, 151, 242, 76, 216, 202, 243, 101, 16, 176, 196, 51, + 225, 204, 68, 217, 185, 59, 42, 36, 255, 117, 191, 219, 145, 31, 220, 37, 17, 237, 35, 7, 194, 57, 168, + 163, 175, 3, 93, 54, 12, 203, 153, 35, 116, 191, 244, 196, 132, 226, 217, 171, 25, 92, 207, 51, 250, 245, + 217, 152, 244, 71, 34, 74, 194, 176, 127, 253, 83, 41, 48, 199, 82, 72, 244, 228, 136, 146, 145, 179, 33, + 103, 202, 149, 92, 159, 26, 233, 154, 133, 183, 194, 36, 2, 206, 31, 209, 131, 230, 109, 0, 109, 92, 49, + 150, 19, 112, 19, 68, 126, 120, 207, 138, 93, 11, 22, 197, 140, 188, 85, 173, 117, 194, 200, 31, 221, 152, + 138, 251, 192, 185, 47, 116, 118, 13, 189, 34, 142, 103, 47, 69, 151, 134, 122, 13, 34, 66, 211, 25, 180, + 151, 227, 104, 201, 175, 96, 248, 114, 140, 90, 145, 203, 90, 171, 129, 187, 219, 162, 245, 119, 107, 84, 171, + 255, 218, 26, 46, 132, 203, 79, 191, 87, 10, 124, 85, 164, 105, 115, 82, 158, 175, 182, 105, 246, 29, 97, + 86, 2, 246, 234, 81, 29, 36, 68, 197, 37, 172, 30, 45, 44, 111, 44, 249, 63, 146, 187, 14, 224, 254, + 52, 218, 75, 54, 185, 165, 92, 214, 70, 247, 253, 79, 232, 112, 230, 42, 58, 53, 15, 98, 135, 139, 76, + 47, 244, 68, 7, 48, 75, 211, 87, 107, 203, 94, 67, 35, 241, 248, 146, 213, 53, 47, 98, 45, 173, 206, + 82, 65, 229, 182, 96, 175, 234, 252, 131, 78, 120, 104, 105, 150, 119, 123, 122, 61, 65, 185, 39, 199, 163, + 127, 193, 222, 149, 184, 233, 75, 98, 170, 150, 196, 226, 235, 242, 26, 192, 17, 84, 218, 119, 62, 58, 59, + 164, 80, 54, 117, 164, 212, 6, 219, 212, 203, 175, 242, 162, 87, 170, 19, 90, 212, 37, 108, 141, 83, 157, + 77, 155, 25, 236, 8, 216, 21, 6, 95, 88, 239, 19, 31, 188, 103, 29, 186, 92, 129, 6, 76, 254, 130, + 242, 209, 181, 53, 29, 251, 164, 165, 111, 47, 89, 122, 196, 103, 252, 112, 127, 133, 76, 110, 172, 82, 119, + 201, 222, 85, 188, 238, 131, 110, 187, 204, 175, 185, 68, 85, 31, 222, 35, 119, 138, 88, 223, 71, 154, 152, + 251, 161, 216, 192, 252, 133, 120, 177, 212, 172, 147, 122, 184, 154, 42, 179, 56, 255, 191, 235, 172, 134, 60, + 98, 187, 243, 199, 253, 250, 84, 72, 228, 121, 168, 20, 168, 99, 120, 93, 242, 177, 110, 195, 143, 217, 54, + 104, 223, 14, 11, 95, 226, 43, 129, 23, 135, 152, 95, 147, 3, 131, 18, 127, 187, 67, 250, 240, 206, 40, + 140, 218, 67, 161, 252, 85, 12, 27, 193, 43, 242, 106, 220, 158, 164, 222, 236, 198, 240, 202, 201, 108, 38, + 52, 96, 238, 24, 175, 118, 65, 194, 116, 236, 248, 29, 1, 106, 176, 24, 229, 192, 172, 95, 83, 180, 253, + 254, 245, 105, 88, 154, 53, 183, 41, 57, 245, 127, 98, 27, 250, 251, 207, 48, 148, 251, 89, 12, 227, 214, + 141, 37, 133, 160, 251, 112, 230, 191, 106, 244, 74, 98, 178, 221, 118, 102, 127, 7, 27, 179, 110, 61, 252, + 133, 107, 97, 23, 98, 244, 211, 98, 89, 191, 44, 170, 197, 43, 240, 29, 199, 248, 210, 239, 148, 106, 209, + 195, 252, 178, 79, 140, 234, 75, 108, 78, 194, 175, 251, 246, 146, 26, 242, 212, 60, 235, 225, 76, 140, 68, + 185, 219, 190, 137, 159, 32, 237, 188, 101, 65, 177, 28, 238, 152, 161, 137, 117, 245, 3, 149, 126, 114, 199, + 39, 49, 255, 13, 15, 53, 186, 74, 245, 25, 245, 197, 251, 129, 47, 25, 153, 198, 133, 226, 167, 88, 94, + 245, 245, 74, 129, 255, 233, 121, 145, 219, 243, 157, 239, 152, 121, 161, 190, 223, 197, 240, 230, 55, 25, 246, + 156, 255, 197, 160, 239, 136, 214, 13, 203, 163, 208, 79, 246, 181, 213, 167, 56, 104, 245, 33, 48, 
+      191, 251,
+      33, 127, 100, 71, 66, 54, 104, 224, 85, 34, 255, 52, 247, 83, 68, 227, 120, 232, 117, 105, 66, 237, 217,
+      169, 175, 191, 17, 72, 214, 5, 99, 191, 227, 121, 171, 67, 226, 190, 150, 152, 81, 255, 3, 156, 119, 228,
+      98, 215};
+  const std::vector<float> fc1_scales = {
+      0.01553376f, 0.015543817f, 0.015551699f, 0.015492203f, 0.015023133f, 0.0154082235f, 0.0155232195f,
+      0.01528402f, 0.015559638f, 0.015533516f, 0.015493423f, 0.015256615f, 0.0152339935f, 0.015549371f,
+      0.015381575f, 0.015576782f, 0.015412793f, 0.015498972f, 0.0151363555f, 0.015505189f, 0.014904913f,
+      0.015218727f, 0.015376769f, 0.015279377f, 0.015432924f, 0.015483502f, 0.015457189f, 0.015407557f,
+      0.0156120695f, 0.014825948f, 0.015501786f, 0.015303297f, 0.015532501f, 0.0152144935f, 0.015333908f,
+      0.01479763f, 0.015206473f, 0.01543629f, 0.015437368f, 0.01513233f, 0.015589874f, 0.015567031f,
+      0.015393224f, 0.014935784f, 0.015579218f, 0.015432265f, 0.015484579f, 0.015261326f, 0.015371274f,
+      0.015189547f, 0.015558099f, 0.014714118f, 0.015086958f, 0.015577158f, 0.014815275f, 0.01525769f,
+      0.015569633f, 0.014951542f, 0.015491992f, 0.015379513f, 0.015588352f, 0.015455488f, 0.015094815f,
+      0.015585413f, 0.0151954815f, 0.015539678f, 0.015179157f, 0.015570812f, 0.015453467f, 0.015222808f,
+      0.015130177f, 0.015514964f, 0.015050512f, 0.013596393f, 0.015181009f, 0.014813691f, 0.015430912f,
+      0.015623035f, 0.015465939f, 0.0155621655f, 0.015619047f, 0.015616288f, 0.015411615f, 0.015294425f,
+      0.015334727f, 0.01536013f, 0.015485667f, 0.015279645f, 0.015232291f, 0.015200818f, 0.014945071f,
+      0.015612004f, 0.015533011f, 0.01562017f, 0.015604494f, 0.015526485f, 0.014934285f, 0.015624931f,
+      0.015617797f, 0.0155350845f, 0.015362147f, 0.015408119f, 0.01547795f, 0.014903402f, 0.0154722165f,
+      0.015608951f, 0.015536772f, 0.015497636f, 0.01543246f, 0.015433108f, 0.015222307f, 0.0156019665f,
+      0.0154854f, 0.014986996f, 0.015555747f, 0.015378246f, 0.015050007f, 0.015395556f, 0.0154241435f,
+      0.015317103f, 0.015418313f, 0.015221456f, 0.015339879f, 0.015616156f, 0.01556934f, 0.015396217f,
+      0.015617745f, 0.015584825f};
+  const std::vector<float> fc2_scales = {
+      0.015234984f, 0.015523607f, 0.015164727f, 0.01548125f, 0.015093872f, 0.015315635f, 0.015266418f,
+      0.015527874f, 0.015592782f, 0.015093137f, 0.014813861f, 0.015202709f, 0.0153913535f, 0.01537223f,
+      0.015511734f, 0.015440272f, 0.015092988f, 0.015597204f, 0.015287647f, 0.015497316f, 0.015502119f,
+      0.015546441f, 0.015100006f, 0.015404332f, 0.015531912f, 0.015555983f, 0.01507354f, 0.015588721f,
+      0.01545357f, 0.015513655f, 0.015537361f, 0.015617292f, 0.015471501f, 0.015559636f, 0.015541913f,
+      0.015565485f, 0.015380409f, 0.015168384f, 0.0155151095f, 0.015469871f, 0.015443675f, 0.015554659f,
+      0.015623292f, 0.014806481f, 0.015374577f, 0.015407367f, 0.015303424f, 0.015412778f, 0.015173398f,
+      0.015220221f, 0.015319703f, 0.015124975f, 0.015372854f, 0.015297962f, 0.015397722f, 0.015355343f,
+      0.015466366f, 0.01507015f, 0.015495513f, 0.015593667f, 0.015281979f, 0.015336113f, 0.015525f,
+      0.01537925f, 0.015516909f, 0.015614616f, 0.015543677f, 0.015600901f, 0.0153762605f, 0.015399329f,
+      0.015290953f, 0.015491776f, 0.015287561f, 0.015271302f, 0.015343454f, 0.015566604f, 0.015624354f,
+      0.01533857f, 0.015119089f, 0.015481008f, 0.015398314f, 0.015596798f, 0.0153150465f, 0.015608612f,
+      0.015555618f, 0.015332868f, 0.015389856f, 0.015581448f, 0.015621847f, 0.015410677f, 0.01556886f,
+      0.015614897f, 0.01547879f, 0.015478665f, 0.015515525f, 0.01555785f, 0.01561863f, 0.015433328f,
+      0.015305866f,
+      0.015573423f, 0.015373498f, 0.0155666135f, 0.015396729f, 0.015547626f, 0.014429122f,
+      0.015496805f, 0.015291028f, 0.015550148f, 0.015425619f, 0.0155315865f, 0.015438886f, 0.015576545f,
+      0.015619017f, 0.01515908f, 0.015479961f, 0.015447514f, 0.015065838f, 0.015309097f, 0.015131723f,
+      0.014979966f, 0.014841583f, 0.015531611f, 0.015469328f, 0.015101345f, 0.015491165f, 0.0155728385f,
+      0.015560919f, 0.015370855f};
+  const std::vector<float> fc3_scales = {
+      0.015415549f, 0.015507627f, 0.014678219f, 0.015550405f, 0.015007719f, 0.015621224f, 0.0155345425f,
+      0.015270567f, 0.015584674f, 0.015545895f, 0.015420519f, 0.015511904f, 0.015497334f, 0.015613152f,
+      0.015344387f, 0.015462939f, 0.015408138f, 0.015263364f, 0.015522234f, 0.015557403f, 0.015617529f,
+      0.0155323f, 0.015070785f, 0.0154183265f, 0.015569469f, 0.014966013f, 0.015585924f, 0.0155711975f,
+      0.01525447f, 0.015368329f, 0.015493156f, 0.015439328f, 0.015451316f, 0.015313955f, 0.015007403f,
+      0.015397709f, 0.015486734f, 0.01554385f, 0.015589319f, 0.015365845f, 0.0152554605f, 0.015575631f,
+      0.015524423f, 0.015446551f, 0.01492084f, 0.015455352f, 0.014697226f, 0.015101928f, 0.01525531f,
+      0.01557962f, 0.015178623f, 0.015425265f, 0.015473807f, 0.015434511f, 0.015518608f, 0.015348455f,
+      0.014946166f, 0.0153529495f, 0.015595689f, 0.015601011f, 0.015585726f, 0.0155280195f, 0.014892634f,
+      0.015474405f, 0.015582396f, 0.01517096f, 0.015513012f, 0.015467694f, 0.015459979f, 0.015562061f,
+      0.015136767f, 0.015591653f, 0.015295904f, 0.014878606f, 0.015608272f, 0.015360581f, 0.015440369f,
+      0.015552597f, 0.0153689645f, 0.015544422f, 0.015161956f, 0.015341356f, 0.015590522f, 0.0155716f,
+      0.0153000355f, 0.015417134f, 0.015441434f, 0.015425701f, 0.015540993f, 0.015532201f, 0.015549095f,
+      0.015335085f, 0.01554049f, 0.015028752f, 0.015245372f, 0.01556482f, 0.015607696f, 0.015421748f,
+      0.0154471155f, 0.015398482f, 0.015602099f, 0.015455678f, 0.015591139f, 0.01557602f, 0.015448909f,
+      0.0153864585f, 0.015211966f, 0.015580256f, 0.015525388f, 0.015311712f, 0.015527213f, 0.015249299f,
+      0.015606547f, 0.0154935885f, 0.015555864f, 0.01537651f, 0.015581995f, 0.015337018f, 0.01547428f,
+      0.015216509f, 0.015208464f, 0.015577957f, 0.015380967f, 0.015528679f, 0.015578562f, 0.015344413f,
+      0.015526013f, 0.015194058f};
+  const std::vector<float> output = {
+      0.04828f, -0.05322f, -0.11176f, 0.09344f, -0.02678f, 0.09827f, 0.06616f, -0.04233f, -0.03937f, 0.1582f,
+      -0.0437f, 0.04413f, 0.0931f, 0.11127f, -0.0747f, -0.10297f, -0.06226f, 0.02866f, -0.1395f, -0.008934f,
+      -0.0385f, 0.1564f, 0.1207f, -0.104f, 0.131f, -0.01776f, -0.00962f, 0.05615f, -0.0129f, -0.01724f,
+      -0.06555f, 0.00729f, -0.02585f, 0.01662f, 0.1351f, -0.02095f, 0.1703f, -0.0237f, -0.1381f, 0.10895f,
+      -0.0724f, 0.04358f, 0.1371f, -0.0707f, 0.02188f, -0.06122f, -0.03586f, -0.01924f, 0.01304f, -0.039f,
+      0.12317f, -0.2336f, 0.0972f, -0.0862f, 0.05716f, 0.05075f, 0.1477f, 0.1316f, -0.05365f, -0.1301f,
+      0.01836f, -0.09186f, 0.0641f, -0.10913f, -0.1576f, 0.0441f, 0.03537f, -0.062f, 0.06915f, 0.02954f,
+      0.1605f, -0.05975f, -0.08435f, 0.1779f, -0.01181f, 0.001026f, 0.1284f, 0.1531f, 0.0571f, -0.1577f,
+      0.05838f, 0.1444f, -0.02432f, 0.10065f, -0.04343f, -0.09296f, 0.0335f, -0.00582f, 0.004944f, -0.013054f,
+      -0.049f, 0.0776f, 0.04633f, 0.0746f, -0.1191f, -0.1118f, -0.209f, -0.09753f, -0.02882f, -0.01466f,
+      -0.08655f, 0.1167f, -0.02155f, 0.05896f, 0.0117f, -0.05618f, 0.0908f, 0.1324f, -0.04462f, 0.04077f,
+      -0.02385f, 0.01863f, 0.0729f, 0.1226f, -0.1261f, -0.0583f, 0.0774f, -0.1523f, 0.2018f, 0.1119f,
+      -0.04095f, -0.01188f, 0.1113f, 0.0502f, 0.00584f, -0.02325f, 0.02837f, 0.04144f};
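+
+  // The uint8 expert weight tensors above hold 4-bit quantized values packed two
+  // per byte; fc1_scales/fc2_scales/fc3_scales are the matching dequantization
+  // scales. The call below uses a Mixtral-style configuration: SiLU activation
+  // with normalized top-2 routing.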
+  RunQMoETest(input, router_probs, fc1_experts_weights, fc2_experts_weights, fc3_experts_weights, fc1_scales,
+              fc2_scales, fc3_scales, output, num_rows, num_experts, hidden_size, inter_size, "silu",
+              1, /*normalize_routing_weights*/
+              2 /*top_k*/);
+}
+#endif
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
index 90b7da255081a..50292f186df15 100644
--- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
+++ b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py
@@ -226,9 +226,9 @@ def __init__(self, config, batch_size, sequence_length):
         w2_list = []
         w3_list = []
         for i in range(self.num_experts):
-            w1_list.append(self.experts[i].w1.weight.transpose(0, 1))
-            w2_list.append(self.experts[i].w2.weight.transpose(0, 1))
-            w3_list.append(self.experts[i].w3.weight.transpose(0, 1))
+            w1_list.append(self.experts[i].w1.weight)
+            w2_list.append(self.experts[i].w2.weight)
+            w3_list.append(self.experts[i].w3.weight)
 
         self.moe_experts_weight1 = torch.stack(w1_list, dim=0)
         self.moe_experts_weight2 = torch.stack(w2_list, dim=0)
diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py
index dbf6ee7dabb0e..aa480a1af4587 100644
--- a/onnxruntime/test/python/transformers/test_parity_moe.py
+++ b/onnxruntime/test/python/transformers/test_parity_moe.py
@@ -249,9 +249,9 @@ def __init__(
             num_experts,
             in_features,
             hidden_features,
-            self.moe_experts.weight1,
+            self.moe_experts.weight1.transpose(1, 2),
             self.moe_experts.bias1,
-            self.moe_experts.weight2,
+            self.moe_experts.weight2.transpose(1, 2),
             self.moe_experts.bias2,
         )