Commit

refactor
wangyems committed Nov 25, 2023
1 parent 953ddd3 commit 3be101b
Showing 7 changed files with 121 additions and 84 deletions.
2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -109,6 +109,8 @@ if (NOT onnxruntime_USE_NCCL)
# Those are string patterns to exclude. Do NOT use stars such as
# collective/*.cc or *.h.
list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc")
list(APPEND contrib_ops_excluded_files "collective/sharded_moe.h")
list(APPEND contrib_ops_excluded_files "collective/sharded_moe.cc")
list(APPEND contrib_ops_excluded_files "collective/sharding.cc")
list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc")
23 changes: 23 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
@@ -24,6 +24,29 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) {
if (type == DataTypeImpl::GetType<uint8_t>()) {
return ncclUint8;
} else if (type == DataTypeImpl::GetType<bool>()) {
// CUDA bool is 8-bit large.
return ncclUint8;
} else if (type == DataTypeImpl::GetType<int8_t>()) {
return ncclInt8;
} else if (type == DataTypeImpl::GetType<int32_t>()) {
return ncclInt32;
} else if (type == DataTypeImpl::GetType<int64_t>()) {
return ncclInt64;
} else if (type == DataTypeImpl::GetType<MLFloat16>()) {
return ncclFloat16;
} else if (type == DataTypeImpl::GetType<float>()) {
return ncclFloat32;
} else if (type == DataTypeImpl::GetType<double>()) {
return ncclFloat64;
} else {
ORT_THROW("Tensor type not supported in NCCL.");
}
}

namespace IPC {
#define FLLOG LOGS_DEFAULT(VERBOSE)
#define FLLOGERRNO LOGS_DEFAULT(WARNING) << "error:" << strerror(errno)
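
Note on the helper added above: GetNcclDataType maps an ORT tensor element type (MLDataType) to the matching ncclDataType_t and throws for unsupported types; bool maps to ncclUint8 because a CUDA bool occupies one byte. A minimal sketch of how an NCCL-based kernel might use it — the helper name AllReduceTensor, its parameter list, and the ncclSum reduction are illustrative assumptions, not code from this commit:

// Illustrative sketch only (not part of this commit): using GetNcclDataType to issue
// an all-reduce over an ORT tensor. `comm` and `stream` are assumed to be supplied by
// the surrounding NCCL kernel, similar to the ShardedMoE kernel later in this commit.
Status AllReduceTensor(const Tensor& input, Tensor& output, ncclComm_t comm, cudaStream_t stream) {
  ncclDataType_t dtype = GetNcclDataType(input.DataType());  // ORT_THROWs on unsupported element types
  NCCL_RETURN_IF_ERROR(ncclAllReduce(input.DataRaw(), output.MutableDataRaw(),
                                     static_cast<size_t>(input.Shape().Size()),
                                     dtype, ncclSum, comm, stream));
  return Status::OK();
}
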
29 changes: 5 additions & 24 deletions onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
@@ -7,11 +7,11 @@

#if defined(ORT_USE_NCCL)
#include <algorithm>
#include <tuple>
#include <optional>
#include <string>
#include <tuple>
#include <nccl.h>
#include <sstream>
#include <string>
#endif

namespace onnxruntime {
@@ -20,28 +20,9 @@ namespace cuda {

#define NCCL_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(NCCL_CALL(expr))

static ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) {
if (type == DataTypeImpl::GetType<uint8_t>()) {
return ncclUint8;
} else if (type == DataTypeImpl::GetType<bool>()) {
// CUDA bool is 8-bit large.
return ncclUint8;
} else if (type == DataTypeImpl::GetType<int8_t>()) {
return ncclInt8;
} else if (type == DataTypeImpl::GetType<int32_t>()) {
return ncclInt32;
} else if (type == DataTypeImpl::GetType<int64_t>()) {
return ncclInt64;
} else if (type == DataTypeImpl::GetType<MLFloat16>()) {
return ncclFloat16;
} else if (type == DataTypeImpl::GetType<float>()) {
return ncclFloat32;
} else if (type == DataTypeImpl::GetType<double>()) {
return ncclFloat64;
} else {
ORT_THROW("Tensor type not supported in NCCL.");
}
}
#if defined(ORT_USE_NCCL)
ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type);
#endif

// -----------------------------------------------------------------------
// Defines a new version of nccl classes
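
The header change above replaces a `static` function definition with a declaration, and the single definition now lives in nccl_kernels.cc. The apparent benefit is the usual one for this pattern: each translation unit that includes the header no longer compiles (and potentially warns about) its own unused copy of the mapping function. Because the return type ncclDataType_t only exists when <nccl.h> is available, the declaration stays behind the same ORT_USE_NCCL guard. A schematic of the split, with placeholder file names:

// some_kernels.h — declaration only; the NCCL type stays behind the ORT_USE_NCCL guard.
#if defined(ORT_USE_NCCL)
ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type);
#endif

// some_kernels.cc — one out-of-line definition shared by every caller.
ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) {
  /* type-to-ncclDataType_t mapping as shown in nccl_kernels.cc above */
}
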
104 changes: 62 additions & 42 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -35,12 +35,13 @@ using namespace ONNX_NAMESPACE;
template <typename T>
ShardedMoE<T>::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("local_experts_start_index", &local_experts_start_index_).IsOK());
rank_to_experts_start_index_.resize(nccl_->Size());
// Initialize rank_to_experts_start_index_[0] to a value to convey that it is not initialized.
rank_to_experts_start_index_[0] = std::numeric_limits<int64_t>::min();
}

template <typename T>
Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
const ncclComm_t comm = nccl_->Comm();

typedef typename ToCudaType<T>::MappedType CudaT;
auto stream = context->GetComputeStream();

@@ -50,31 +51,8 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

// Creating a {Rank, ExpertsStartIndex} map on Host.
std::vector<int64_t> rank_to_experts_start_index(nccl_->Size());
IAllocatorUniquePtr<int64_t> experts_start_index_d =
IAllocator::MakeUniquePtr<int64_t>(allocator, 1, false, stream);
IAllocatorUniquePtr<int64_t> rank_to_experts_start_index_d =
IAllocator::MakeUniquePtr<int64_t>(allocator, nccl_->Size(), false, stream);

CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(),
&local_experts_start_index_,
sizeof(int64_t),
cudaMemcpyHostToDevice,
Stream(context)));
NCCL_RETURN_IF_ERROR(ncclAllGather(reinterpret_cast<const char*>(experts_start_index_d.get()),
reinterpret_cast<char*>(rank_to_experts_start_index_d.get()),
1,
ncclInt64,
comm,
Stream(context)));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(rank_to_experts_start_index.data(),
rank_to_experts_start_index_d.get(),
nccl_->Size() * sizeof(int64_t),
cudaMemcpyDeviceToHost,
Stream(context)));

CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(Stream(context)));
// Create a {Rank, ExpertsStartIndex} map on Host.
ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context));

const Tensor* input = context->Input<Tensor>(0);
const Tensor* router_probs = context->Input<Tensor>(1);
@@ -135,21 +113,25 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {

Tensor* output = context->Output(0, input->Shape());

size_t base_offset = moe_params.hidden_size * sizeof(CudaT);
ncclDataType_t dtype = GetNcclDataType(input->DataType());
NCCL_RETURN_IF_ERROR(ncclGroupStart()); // flacky results with ncclBroadcast + ncclGroupStart/End
for (int r = 0; r < nccl_->Size(); ++r) {
int64_t experts_start_index = rank_to_experts_start_index[r];
int64_t total_past_rows = 0;
int64_t total_covered_rows = 0;
moe_runner.get_total_rows_info(experts_start_index, moe_params.local_num_experts, total_past_rows, total_covered_rows);
// std::cout << "rank: " << r << ", experts_start_index: " << experts_start_index << ", total_past_rows: " << total_past_rows << ", total_covered_rows: " << total_covered_rows << std::endl;
NCCL_RETURN_IF_ERROR(ncclBroadcast(reinterpret_cast<const char*>(fc2_output.get()) + total_past_rows * base_offset,
reinterpret_cast<char*>(fc2_output_bc.get()) + total_past_rows * base_offset,
total_covered_rows * moe_params.hidden_size,
dtype,
r,
comm,
size_t stride_count = moe_params.hidden_size;
size_t stride_bytes = stride_count * sizeof(CudaT);
int64_t total_past_rows = 0;
int64_t total_covered_rows = 0;
NCCL_RETURN_IF_ERROR(ncclGroupStart());
for (int rank = 0; rank < nccl_->Size(); ++rank) {
int64_t experts_start_index = rank_to_experts_start_index_[rank];
moe_runner.get_total_rows_info(experts_start_index,
moe_params.local_num_experts,
total_past_rows,
total_covered_rows);
const char* src = reinterpret_cast<const char*>(fc2_output.get()) + total_past_rows * stride_bytes;
char* dst = reinterpret_cast<char*>(fc2_output_bc.get()) + total_past_rows * stride_bytes;
NCCL_RETURN_IF_ERROR(ncclBroadcast(src,
dst,
total_covered_rows * stride_count,
GetNcclDataType(input->DataType()),
rank,
nccl_->Comm(),
Stream(context)));
}
NCCL_RETURN_IF_ERROR(ncclGroupEnd());
Expand All @@ -167,6 +149,44 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
return Status::OK();
}

template <typename T>
Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator, OpKernelContext* context) const {
if (rank_to_experts_start_index_[0] != std::numeric_limits<int64_t>::min()) {

[cpplint warning on line 154 of sharded_moe.cc] Add #include <limits> for numeric_limits<> [build/include_what_you_use] [4]
return Status::OK();
}

auto stream = context->GetComputeStream();

using IndexType = int64_t;
size_t IndexTypeSize = sizeof(IndexType);

IAllocatorUniquePtr<IndexType> experts_start_index_d =
IAllocator::MakeUniquePtr<IndexType>(allocator, 1, false, stream);
IAllocatorUniquePtr<IndexType> rank_to_experts_start_index_d =
IAllocator::MakeUniquePtr<IndexType>(allocator, nccl_->Size(), false, stream);

// Only happens in the first run.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(),
&local_experts_start_index_,
IndexTypeSize,
cudaMemcpyHostToDevice,
Stream(context)));
NCCL_RETURN_IF_ERROR(ncclAllGather(reinterpret_cast<const char*>(experts_start_index_d.get()),
reinterpret_cast<char*>(rank_to_experts_start_index_d.get()),
1,
GetNcclDataType(DataTypeImpl::GetType<IndexType>()),
nccl_->Comm(),
Stream(context)));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast<int64_t*>(rank_to_experts_start_index_.data()),
rank_to_experts_start_index_d.get(),
nccl_->Size() * IndexTypeSize,
cudaMemcpyDeviceToHost,
Stream(context)));

CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(Stream(context)));

return Status::OK();
}
#endif

} // namespace cuda
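
Taken together, the constructor and SynchronizeExpertsStartIndex implement a lazy, one-time exchange: slot 0 of rank_to_experts_start_index_ is seeded with std::numeric_limits<int64_t>::min() as a "not yet initialized" sentinel, the first ComputeInternal call all-gathers every rank's local_experts_start_index into the cached vector, and later calls return early. A simplified, self-contained sketch of the same pattern — the function name SynchronizeOnce, its parameter list, and the preallocated device scratch buffers are assumptions for illustration, not the commit's exact code:

// Sketch only (assumptions noted above): fill `cache` once with each rank's start index.
Status SynchronizeOnce(std::vector<int64_t>& cache,   // indexed by rank
                       int64_t local_start_index,
                       int64_t* device_in,            // device scratch: 1 element
                       int64_t* device_out,           // device scratch: world_size elements
                       int world_size,
                       ncclComm_t comm,
                       cudaStream_t stream) {
  if (cache[0] != std::numeric_limits<int64_t>::min()) {
    return Status::OK();  // already synchronized on an earlier call
  }
  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(device_in, &local_start_index, sizeof(int64_t),
                                       cudaMemcpyHostToDevice, stream));
  NCCL_RETURN_IF_ERROR(ncclAllGather(device_in, device_out, 1, ncclInt64, comm, stream));
  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cache.data(), device_out, world_size * sizeof(int64_t),
                                       cudaMemcpyDeviceToHost, stream));
  // The host reads the gathered values immediately afterwards, so wait for the copies to land.
  CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
  return Status::OK();
}

In the commit itself the cache is a member of a kernel whose ComputeInternal is const, which is why the final copy writes through const_cast<int64_t*>(rank_to_experts_start_index_.data()).
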
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
@@ -23,7 +23,10 @@ class ShardedMoE final : public NcclKernel, public MoEBase {
Status ComputeInternal(OpKernelContext* ctx) const override;

private:
Status SynchronizeExpertsStartIndex(AllocatorPtr& alloc, OpKernelContext* ctx) const;

int64_t local_experts_start_index_;
std::vector<int64_t> rank_to_experts_start_index_;

[cpplint warning on line 29 of sharded_moe.h] Add #include <vector> for vector<> [build/include_what_you_use] [4]
};

#endif
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -130,8 +130,8 @@ class CutlassMoeFCRunner {
void dispatch_activations(int64_t* total_rows_before_expert, int num_experts, int local_num_experts,
int local_experts_start_index, cudaStream_t stream);

void get_total_rows_info(int64_t experts_start_index, int64_t local_num_experts, int64_t& total_past_rows, int64_t& total_covered_rows){
int experts_end_index = experts_start_index + local_num_experts - 1;
void get_total_rows_info(int64_t experts_start_index, int64_t local_num_experts, int64_t& total_past_rows, int64_t& total_covered_rows) {

[cpplint warning on line 133 of moe_kernel.h] Lines should be <= 120 characters long [whitespace/line_length] [2]
int64_t experts_end_index = experts_start_index + local_num_experts - 1;
total_past_rows = 0;
if (experts_start_index > 0) {
total_past_rows = total_rows_before_expert_host_[experts_start_index - 1];
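
Only part of get_total_rows_info is visible above; the lines shown derive total_past_rows from a host-side cumulative row count per expert (total_rows_before_expert_host_, a member of CutlassMoeFCRunner). Assuming that array holds the running total of rows through each expert, the whole helper plausibly reduces to the following — the total_covered_rows expression is an inference from the visible code, not confirmed by this diff:

// Inferred sketch, not the verbatim commit: compute the row window owned by the local
// experts from the runner's per-expert cumulative row counts.
void get_total_rows_info(int64_t experts_start_index, int64_t local_num_experts,
                         int64_t& total_past_rows, int64_t& total_covered_rows) {
  int64_t experts_end_index = experts_start_index + local_num_experts - 1;
  // Rows routed to experts that precede this rank's first local expert.
  total_past_rows = (experts_start_index > 0)
                        ? total_rows_before_expert_host_[experts_start_index - 1]
                        : 0;
  // Rows routed to experts [experts_start_index, experts_end_index]: cumulative rows
  // through the last local expert minus everything before the first local expert.
  total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows;
}
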
(Diff of the Python parity test for MoE expert slicing; the file path is not shown in this view.)
@@ -1,10 +1,16 @@
import os
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import unittest
from mpi4py import MPI
import onnxruntime

import numpy as np
from mpi4py import MPI
from onnx import TensorProto, helper

import onnxruntime

np.random.seed(3)

comm = MPI.COMM_WORLD
@@ -51,7 +57,7 @@ def create_moe_onnx_graph(
fc2_experts_bias,
local_experts_start_index=-1,
):
use_sharded_moe = True if local_experts_start_index >= 0 else False
use_sharded_moe = local_experts_start_index >= 0
nodes = [
helper.make_node(
"MoE",
@@ -164,18 +170,6 @@ def test_moe_with_expert_slicing(
num_experts,
num_rows,
):
print_out(
"hidden_size: ",
hidden_size,
" inter_size: ",
inter_size,
" num_experts: ",
num_experts,
" num_rows: ",
num_rows,
" world_size: ",
get_size(),
)
local_experts_start_index = local_rank * num_experts // get_size()

fc1_experts_weights_all = np.random.rand(num_experts, hidden_size, inter_size).astype(NP_TYPE)
@@ -235,6 +229,20 @@ def test_moe_with_expert_slicing(

assert np.allclose(output[0], sharded_output[0], atol=THRESHOLD, rtol=THRESHOLD)

print_out(
"hidden_size: ",
hidden_size,
" inter_size: ",
inter_size,
" num_experts: ",
num_experts,
" num_rows: ",
num_rows,
" world_size: ",
get_size(),
" Parity: OK",
)


class TestMoE(unittest.TestCase):
def test_moe_expert_slicing(self):
