
Commit

review comments
wangyems committed Mar 19, 2024
1 parent 69f53b8 commit 42de225
Showing 5 changed files with 82 additions and 16 deletions.
7 changes: 5 additions & 2 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -58,7 +58,6 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
// Create a {Rank, ExpertsStartIndex} map on Host.
AutoDestoryCudaEvent cuda_event;
cudaEvent_t& copy_event = cuda_event.Get();
ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));

const Tensor* input = context->Input<Tensor>(0);
const Tensor* router_probs = context->Input<Tensor>(1);
@@ -77,6 +76,10 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
"num_experts should be divisible by world_size");

if (moe_params.parallel_type == MoEParallelType::EP || moe_params.parallel_type == MoEParallelType::EPAndTP) {
ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));
}

ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm,
fc3_experts_weights_optional != nullptr,
normalize_routing_weights_);
@@ -136,7 +139,7 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
}

if (moe_params.parallel_type == MoEParallelType::EPAndTP) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "EPAndTP is not supported yet");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expert and Tensor Parallelism is not supported yet");
}

if (moe_params.parallel_type == MoEParallelType::TP) {
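Note: under expert parallelism each rank owns a contiguous slice of the experts, so the kernel needs the {Rank, ExpertsStartIndex} map before it can route tokens; pure tensor parallelism shards weights rather than experts, so it can skip that synchronization, which is why the copy is now gated on EP and EPAndTP. A minimal host-side sketch of how such a start index could be derived (illustrative only; compute_experts_start_index is a hypothetical helper, not part of the operator):

    // Hypothetical sketch: evenly partition num_experts across world_size ranks,
    // which is valid because the kernel checks num_experts % world_size == 0.
    #include <cstdint>

    inline int64_t compute_experts_start_index(int64_t num_experts, int rank, int world_size) {
      const int64_t experts_per_rank = num_experts / world_size;   // exact division by the check above
      return static_cast<int64_t>(rank) * experts_per_rank;        // first expert owned by this rank
    }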
81 changes: 72 additions & 9 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -127,7 +127,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,

const bool should_process_row = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts;
T output_row_sum = T(0.f);
float output_row_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
thread_kvp.key = 0;
thread_kvp.value = T(-1.f); // This is OK because inputs are probabilities
@@ -159,7 +159,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
if (normalize_routing_weights && k_idx == k - 1) {
#pragma unroll
for (int ki = 0; ki < k; ++ki) {
output[idx - ki] /= output_row_sum;
output[idx - ki] = T(static_cast<float>(output[idx - ki]) / output_row_sum);
}
}
}
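The hunk above accumulates the per-row sum of the top-k routing weights in float and performs the division in float before converting back to T, so half-precision outputs are not normalized with a half-precision accumulator. A small sketch of the intended numerics (normalize_topk_weights is a made-up name for illustration):

    // Illustrative only: normalize k routing weights with a float accumulator.
    template <typename T>
    __device__ void normalize_topk_weights(T* weights, int k) {
      float row_sum = 0.f;
      for (int i = 0; i < k; ++i) {
        row_sum += static_cast<float>(weights[i]);  // accumulate in float
      }
      for (int i = 0; i < k; ++i) {
        weights[i] = T(static_cast<float>(weights[i]) / row_sum);  // divide in float, cast back to T
      }
    }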
@@ -575,7 +575,7 @@ size_t CutlassMoeFCRunner<T, WeightType, Enable>::getWorkspaceSize(size_t num_ro
total_ws_bytes += num_softmax_outs * sizeof(T);
const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T);
const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows));
sorter_.update_num_experts(num_experts);
sorter_.update_num_experts(static_cast<int>(num_experts));

size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
@@ -618,6 +618,59 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::configure_ws_ptrs(char* ws_ptr,
}
}

namespace {

struct __align__(8) Half4 {
half2 x;
half2 y;
};

// TODO(wy): move to common header
template <typename T>
struct T4;
template <>
struct T4<float> {
using Type = float4;
};
template <>
struct T4<half> {
using Type = Half4;
};

template <typename T>
struct T2;
template <>
struct T2<float> {
using Type = float2;
};
template <>
struct T2<half> {
using Type = half2;
};

inline __device__ float2 operator*(const float2 a, const float2 b) {
return make_float2(a.x * b.x, a.y * b.y);
}

inline __device__ float4 operator*(const float4 a, const float4 b) {
return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}

inline __device__ Half4 operator*(const Half4 a, const Half4 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
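// Devices below compute capability 5.3 lack native half arithmetic (e.g. __hmul2), so multiply in float and convert back.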
Half4 result;
result.x.x = static_cast<half>(static_cast<float>(a.x.x) * static_cast<float>(b.x.x));
result.x.y = static_cast<half>(static_cast<float>(a.x.y) * static_cast<float>(b.x.y));
result.y.x = static_cast<half>(static_cast<float>(a.y.x) * static_cast<float>(b.y.x));
result.y.y = static_cast<half>(static_cast<float>(a.y.y) * static_cast<float>(b.y.y));
return result;
#else
return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)};
#endif
}

}  // anonymous namespace
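For reference, the T4/T2 traits above simply map a scalar element type to its packed vector counterpart so that elementWiseMulKernel can be instantiated on wider types and issue wider loads and stores. Illustrative assertions (not part of the change) showing what they resolve to:

    // Illustrative only: Half4 packs four halves (8 bytes); float4 packs four floats (16 bytes).
    static_assert(sizeof(T4<half>::Type) == 4 * sizeof(half), "Half4 is four halves wide");
    static_assert(sizeof(T4<float>::Type) == 4 * sizeof(float), "float4 is four floats wide");
    static_assert(sizeof(T2<half>::Type) == 2 * sizeof(half), "half2 is two halves wide");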


template <typename T>
__global__ void elementWiseMulKernel(T* output, T const* input, size_t inter_size) {
int const tid = threadIdx.x;
@@ -627,20 +680,30 @@ __global__ void elementWiseMulKernel(T* output, T const* input, size_t inter_siz
input = input + token * inter_size;
for (int i = tid; i < inter_size; i += blockDim.x) {
T fc1_value = input[i];
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
output[i] = T(float(fc1_value) * float(output[i]));
#else
output[i] = fc1_value * output[i];
#endif
}
}

template <typename T>
void elementWiseMul(T* output, T const* input, int inter_size, int num_tokens, cudaStream_t stream) {
int const blocks = num_tokens;
int const threads = std::min(inter_size, 1024);

elementWiseMulKernel<<<blocks, threads, 0, stream>>>(output, input, inter_size);
if ((inter_size & 3) == 0) {
using vec_type = typename T4<T>::Type;
int const threads = std::min(inter_size / 4, 1024);
elementWiseMulKernel<vec_type><<<blocks, threads, 0, stream>>>(reinterpret_cast<vec_type*>(output),
reinterpret_cast<vec_type const*>(input),
inter_size / 4);
} else if ((inter_size & 1) == 0) {
using vec_type = typename T2<T>::Type;
int const threads = std::min(inter_size / 2, 1024);
elementWiseMulKernel<vec_type><<<blocks, threads, 0, stream>>>(reinterpret_cast<vec_type*>(output),
reinterpret_cast<vec_type const*>(input),
inter_size / 2);
} else {
int const threads = std::min(inter_size, 1024);
elementWiseMulKernel<T><<<blocks, threads, 0, stream>>>(output, input, inter_size);
}
}
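elementWiseMul is the host-side launcher: it dispatches to the widest instantiation the row width allows (4-wide when inter_size is divisible by 4, 2-wide when divisible by 2, scalar otherwise), so each thread moves more data per memory transaction. A hedged usage sketch with hypothetical buffer names; the real call site in the MoE runner may differ:

    // Illustrative only: multiply the activated FC1 branch by the FC3 branch in place,
    // assuming both are device buffers of shape [num_tokens, inter_size].
    elementWiseMul<half>(fc1_output /* in/out */, fc3_output, inter_size, num_tokens, stream);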

template <typename T, typename WeightType, typename Enable>
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1383,7 +1383,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
constexpr const char* MoE_ver1_doc = R"DOC(
Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) uses top 1,
GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, Vision MOE(https://arxiv.org/pdf/2106.05974.pdf)
usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral)
usually uses top 32 experts and Mixtral(https://huggingface.co/blog/mixtral).
)DOC";

ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1,
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -543,15 +543,15 @@ struct _IsNan {
template <>
struct _IsNan<half> {
__device__ __inline__ bool operator()(half a) const {
return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~MLFloat16::kSignMask)
return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~MLFloat16::kSignMask)
> MLFloat16::kPositiveInfinityBits;
}
};

template <>
struct _IsNan<BFloat16> {
__device__ __inline__ bool operator()(BFloat16 a) const {
return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~BFloat16::kSignMask)
return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~BFloat16::kSignMask)
> BFloat16::kPositiveInfinityBits;
}
};
Original file line number Diff line number Diff line change
@@ -178,7 +178,7 @@ def __init__(
self.router_aux_loss_coef = router_aux_loss_coef


class MixtralBLockSparseTop2MLP(nn.Module):
class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
@@ -220,7 +220,7 @@ def __init__(self, config, batch_size, sequence_length):
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

w1_list = []
w2_list = []
