Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyems committed Mar 26, 2024
1 parent f81effb commit 3c5028d
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 28 deletions.
2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ set(contrib_ops_excluded_files
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"
"quantization/moe_quantization.h"
"quantization/moe_quantization.cc"
"quantization/quantize_dequantize_linear.cc"
"quantization/qordered_ops/qordered_attention_impl.cu"
"quantization/qordered_ops/qordered_attention_impl.h"
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
MoEParameters moe_params(tensor_shards_);
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional,
fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional,
fc3_experts_bias_optional));
fc3_experts_bias_optional, false));

ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
"num_experts should be divisible by world_size");
Expand Down
14 changes: 2 additions & 12 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
#include "cub/util_type.cuh"
#endif

#include "core/providers/cuda/cu_inc/common.cuh"

namespace ort_fastertransformer {
static constexpr int WARP_SIZE = 32;

Expand Down Expand Up @@ -645,23 +647,11 @@ inline __device__ float4 operator*(const float4 a, const float4 b) {
return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}

// Compatibility shims: define half/half2 multiplication only where the
// toolkit does not already supply these operators -- device code targeting
// pre-SM53 architectures built with a compiler older than CUDA 12.2.
// NOTE(review): the 12.2 cutoff presumably matches when cuda_fp16.h began
// providing these operators unconditionally; confirm against the toolkit.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \
((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
// half * half: pre-SM53 has no native fp16 arithmetic, so round-trip via float.
inline __device__ half operator*(const half a, const half b) { return __float2half(__half2float(a) * __half2float(b)); }

// half2 * half2: elementwise product of the two packed half lanes.
inline __device__ half2 operator*(const half2 a, const half2 b) { return make_half2(a.x * b.x, a.y * b.y); }
#endif

// Elementwise product of two Half4 values (Half4 is a project-declared pair
// of half2 lanes, judging by the .x/.y member usage here -- TODO confirm).
// Both branches compute the same result; only the construction style differs.
inline __device__ Half4 operator*(const Half4 a, const Half4 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \
((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
// Older-toolkit / pre-SM53 path: member-wise assignment, relying on the
// half2 operator* shim defined above this function in the original file.
Half4 result;
result.x = a.x * b.x;
result.y = a.y * b.y;
return result;
#else
// Modern path: aggregate-initialize directly from the lane products.
return Half4{a.x * b.x, a.y * b.y};
#endif
}

} // anonymous namespace
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/moe/moe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
MoEParameters moe_params;
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional,
fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional,
fc3_experts_bias_optional));
fc3_experts_bias_optional, false));

typedef typename ToCudaType<T>::MappedType CudaT;
auto stream = context->GetComputeStream();
Expand Down
14 changes: 0 additions & 14 deletions onnxruntime/contrib_ops/cuda/moe/moe_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,20 +172,6 @@ class MoEBase {
return Status::OK();
}

// Backward-compatible convenience overload of CheckInputs: forwards all
// tensor arguments unchanged to the primary overload, fixing its trailing
// boolean argument to false.
// NOTE(review): the primary overload's signature is not visible in this
// view, so the flag's name/meaning cannot be stated here -- presumably it
// selects a quantized-MoE validation path; confirm in moe_base.h.
// Returns whatever Status the primary overload produces (Status::OK() on
// valid inputs, an error Status otherwise).
Status CheckInputs(MoEParameters parameters,
const Tensor* input,
const Tensor* router_probs,
const Tensor* fc1_experts_weights,
const Tensor* fc1_experts_bias_optional,
const Tensor* fc2_experts_weights,
const Tensor* fc2_experts_bias_optional,
const Tensor* fc3_experts_weights_optional,
const Tensor* fc3_experts_bias_optional) const {
// Delegate with the extra flag defaulted to false so existing callers
// keep their previous behavior.
return CheckInputs(parameters, input, router_probs, fc1_experts_weights, fc1_experts_bias_optional,
fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional,
fc3_experts_bias_optional, false);
}

protected:
MoEBase(const OpKernelInfo& op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("k", &k_).IsOK());
Expand Down

0 comments on commit 3c5028d

Please sign in to comment.