diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index cadb06bb38707..0051f241e4f9b 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -60,6 +60,8 @@ set(contrib_ops_excluded_files "quantization/matmul_nbits.cc" "quantization/matmul_nbits.cuh" "quantization/matmul_nbits.cu" + "quantization/moe_quantization.h" + "quantization/moe_quantization.cc" "quantization/quantize_dequantize_linear.cc" "quantization/qordered_ops/qordered_attention_impl.cu" "quantization/qordered_ops/qordered_attention_impl.h" diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 2efc37cf98010..75100f5fe0a86 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -71,7 +71,7 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { MoEParameters moe_params(tensor_shards_); ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional)); + fc3_experts_bias_optional, false)); ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 294925538b015..5c50be4eae57e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -47,6 +47,8 @@ #include "cub/util_type.cuh" #endif +#include "core/providers/cuda/cu_inc/common.cuh" + namespace ort_fastertransformer { static constexpr int WARP_SIZE = 32; @@ -645,23 +647,11 @@ inline __device__ float4 operator*(const float4 a, const float4 b) { return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); } -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ - ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) -inline __device__ half operator*(const half a, const half b) { return __float2half(__half2float(a) * __half2float(b)); } - -inline __device__ half2 operator*(const half2 a, const half2 b) { return make_half2(a.x * b.x, a.y * b.y); } -#endif - inline __device__ Half4 operator*(const Half4 a, const Half4 b) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ - ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) Half4 result; result.x = a.x * b.x; result.y = a.y * b.y; return result; -#else - return Half4{a.x * b.x, a.y * b.y}; -#endif } } // anonymous namespace diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index b13aab959fc48..a486db7a07fda 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -48,7 +48,7 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { MoEParameters moe_params; ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional)); + fc3_experts_bias_optional, false)); typedef typename ToCudaType::MappedType CudaT; auto stream = context->GetComputeStream(); diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index ee28e3066325e..9155a1364c6dd 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -172,20 +172,6 @@ class MoEBase { return Status::OK(); } - Status CheckInputs(MoEParameters parameters, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional) const { - return CheckInputs(parameters, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional, - fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional, - fc3_experts_bias_optional, false); - } - protected: MoEBase(const OpKernelInfo& op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("k", &k_).IsOK());