MoE with Expert Slicing (#18565)
### Description

Registers a ShardedMoE op under contrib_ops/cuda/collective with expert
slicing. The broadcast happens just before the second bias (if present) is
added and the permutation is undone. Tensor slicing is planned but not
included in this PR.
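
To make the slicing scheme concrete, here is a minimal sketch (not part of this PR; `ExpertPartition` and `ComputePartition` are illustrative names) of how a contiguous block of experts maps to each rank. The kernel receives the resulting start index through its `local_experts_start_index` attribute and requires `num_experts` to be divisible by the world size.

```cpp
#include <cassert>
#include <cstdint>

struct ExpertPartition {
  int64_t local_num_experts;          // number of experts owned by this rank
  int64_t local_experts_start_index;  // first expert id owned by this rank
};

// Each rank owns a contiguous slice of experts; num_experts must be divisible
// by world_size, matching the check the kernel performs.
ExpertPartition ComputePartition(int64_t num_experts, int rank, int world_size) {
  assert(num_experts % world_size == 0);
  const int64_t local_num_experts = num_experts / world_size;
  return {local_num_experts, rank * local_num_experts};
}

// Example: 8 experts on 4 ranks -> each rank owns 2 experts;
// rank 0 starts at expert 0, rank 1 at expert 2, rank 2 at 4, rank 3 at 6.
```

At runtime each rank runs FC1/FC2 only for the token rows routed to its local experts; the per-rank slices of the FC2 output are then exchanged with one `ncclBroadcast` per rank inside a `ncclGroupStart()`/`ncclGroupEnd()` block, right before the second bias is added and the permutation is undone.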

### Motivation and Context
wangyems authored Dec 6, 2023
1 parent 871c529 commit c012e41
Showing 16 changed files with 884 additions and 144 deletions.
2 changes: 2 additions & 0 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -34,6 +34,8 @@
if (NOT onnxruntime_USE_NCCL)
list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/nccl_kernels.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharded_moe.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharded_moe.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding_spec.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc"
2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -109,6 +109,8 @@ if (NOT onnxruntime_USE_NCCL)
# Those are string patterns to exclude. Do NOT use stars such as
# collective/*.cc or *.h.
list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc")
list(APPEND contrib_ops_excluded_files "collective/sharded_moe.h")
list(APPEND contrib_ops_excluded_files "collective/sharded_moe.cc")
list(APPEND contrib_ops_excluded_files "collective/sharding.cc")
list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc")
@@ -3,7 +3,7 @@

#pragma once

#include "core/providers/cuda/cuda_common.h"
#include <cuda.h>

namespace onnxruntime {
namespace contrib {
4 changes: 1 addition & 3 deletions onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
@@ -24,9 +24,7 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

#define NCCL_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(NCCL_CALL(expr))

static ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) {
ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) {
if (type == DataTypeImpl::GetType<uint8_t>()) {
return ncclUint8;
} else if (type == DataTypeImpl::GetType<bool>()) {
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
@@ -7,17 +7,21 @@

#if defined(ORT_USE_NCCL)
#include <algorithm>
#include <tuple>
#include <optional>
#include <string>
#include <tuple>
#include <nccl.h>
#include <sstream>
#include <string>
#endif

namespace onnxruntime {
namespace contrib {
namespace cuda {

#define NCCL_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(NCCL_CALL(expr))

ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type);

// -----------------------------------------------------------------------
// Defines a new version of nccl classes
// that independent with training::DistributedRunContext, only rely on MPI
204 changes: 204 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -0,0 +1,204 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
#include "sharded_moe.h"

using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ShardedMoE, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.MayInplace(0, 0) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
ShardedMoE<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)

using namespace ONNX_NAMESPACE;

template <typename T>
ShardedMoE<T>::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("local_experts_start_index", &local_experts_start_index_).IsOK());
rank_to_experts_start_index_.resize(nccl_->Size());
// Initialize rank_to_experts_start_index_[0] to a value to convey that it is not initialized.
rank_to_experts_start_index_[0] = std::numeric_limits<int64_t>::min();
}

template <typename T>
Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
auto stream = context->GetComputeStream();

auto& device_prop = GetDeviceProp();
const int sm = device_prop.major * 10 + device_prop.minor;

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

// Create a {Rank, ExpertsStartIndex} map on Host.
AutoDestoryCudaEvent cuda_event;
cudaEvent_t& copy_event = cuda_event.Get();
ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));

const Tensor* input = context->Input<Tensor>(0);
const Tensor* router_probs = context->Input<Tensor>(1);
const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
const Tensor* fc2_experts_weights = context->Input<Tensor>(3);
const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4);
const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(5);

MoEParameters moe_params;
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights,
fc1_experts_bias_optional, fc2_experts_bias_optional));
ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
"num_experts should be divisible by world_size");

ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm);

size_t ws_size =
moe_runner.getWorkspaceSize(static_cast<int>(moe_params.num_rows), static_cast<int>(moe_params.hidden_size),
static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
static_cast<int>(k_));

size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int);
size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int);

// TODO: allocate one buffer and reuse it.
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, ws_size, false, stream);
IAllocatorUniquePtr<void> fc2_output = IAllocator::MakeUniquePtr<void>(allocator, fc2_output_size, false, stream);
IAllocatorUniquePtr<void> fc2_output_bc = IAllocator::MakeUniquePtr<void>(allocator, fc2_output_size, false, stream);
IAllocatorUniquePtr<void> expert_scales =
IAllocator::MakeUniquePtr<void>(allocator, expert_scales_size, false, stream);
IAllocatorUniquePtr<void> expanded_source_row_to_expanded_dest_row =
IAllocator::MakeUniquePtr<void>(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream);
IAllocatorUniquePtr<void> expert_for_source_row =
IAllocator::MakeUniquePtr<void>(allocator, expert_for_source_row_size, false, stream);

// fc1_scales and fc2_scales are used in quantized MoE
const CudaT* fc1_scales_ptr = nullptr;
const CudaT* fc2_scales_ptr = nullptr;

moe_runner.run_moe_fc(reinterpret_cast<const CudaT*>(input->template Data<T>()),
reinterpret_cast<const CudaT*>(router_probs->template Data<T>()),
reinterpret_cast<const CudaT*>(fc1_experts_weights->template Data<T>()),
std::move(fc1_scales_ptr),
fc1_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc1_experts_bias_optional->template Data<T>()),
activation_type_, reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()),
std::move(fc2_scales_ptr), static_cast<int>(moe_params.num_rows),
static_cast<int>(moe_params.hidden_size),
static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
static_cast<int>(moe_params.local_num_experts), static_cast<int>(local_experts_start_index_),
static_cast<int>(k_), reinterpret_cast<char*>(work_space.get()),
reinterpret_cast<CudaT*>(fc2_output.get()), reinterpret_cast<CudaT*>(expert_scales.get()),
reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
reinterpret_cast<int*>(expert_for_source_row.get()), Stream(context));

Tensor* output = context->Output(0, input->Shape());

size_t stride_count = moe_params.hidden_size;
size_t stride_bytes = stride_count * sizeof(CudaT);
int64_t total_past_rows = 0;
int64_t total_covered_rows = 0;
if (copy_event != nullptr) {
CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event));
}
NCCL_RETURN_IF_ERROR(ncclGroupStart());
for (int rank = 0; rank < nccl_->Size(); ++rank) {
int64_t experts_start_index = rank_to_experts_start_index_[rank];
moe_runner.get_total_rows_info(experts_start_index,
moe_params.local_num_experts,
total_past_rows,
total_covered_rows);
const char* src = reinterpret_cast<const char*>(fc2_output.get()) + total_past_rows * stride_bytes;
char* dst = reinterpret_cast<char*>(fc2_output_bc.get()) + total_past_rows * stride_bytes;
NCCL_RETURN_IF_ERROR(ncclBroadcast(src,
dst,
total_covered_rows * stride_count,
GetNcclDataType(input->DataType()),
rank,
nccl_->Comm(),
Stream(context)));
}
NCCL_RETURN_IF_ERROR(ncclGroupEnd());

ort_fastertransformer::finalize_moe_routing_kernelLauncher(
reinterpret_cast<CudaT*>(fc2_output_bc.get()), reinterpret_cast<CudaT*>(output->template MutableData<T>()),
fc2_experts_bias_optional == nullptr
? nullptr
: reinterpret_cast<const CudaT*>(fc2_experts_bias_optional->template Data<T>()),
reinterpret_cast<CudaT*>(expert_scales.get()),
reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
reinterpret_cast<int*>(expert_for_source_row.get()), static_cast<int>(moe_params.num_rows),
static_cast<int>(moe_params.hidden_size), static_cast<int>(k_), Stream(context));

return Status::OK();
}

template <typename T>
Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
OpKernelContext* context,
cudaEvent_t& cuda_event) const {
if (rank_to_experts_start_index_[0] != std::numeric_limits<int64_t>::min()) {
return Status::OK();
}

auto stream = context->GetComputeStream();

using IndexType = int64_t;
size_t IndexTypeSize = sizeof(IndexType);

IAllocatorUniquePtr<IndexType> experts_start_index_d =
IAllocator::MakeUniquePtr<IndexType>(allocator, 1, false, stream);
IAllocatorUniquePtr<IndexType> rank_to_experts_start_index_d =
IAllocator::MakeUniquePtr<IndexType>(allocator, nccl_->Size(), false, stream);

// Only happens in the first run.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(),
&local_experts_start_index_,
IndexTypeSize,
cudaMemcpyHostToDevice,
Stream(context)));
NCCL_RETURN_IF_ERROR(ncclAllGather(reinterpret_cast<const char*>(experts_start_index_d.get()),
reinterpret_cast<char*>(rank_to_experts_start_index_d.get()),
1,
GetNcclDataType(DataTypeImpl::GetType<IndexType>()),
nccl_->Comm(),
Stream(context)));
// The const_cast<> violates the const modifier to make sure the synchronization happens only once per session.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast<int64_t*>(rank_to_experts_start_index_.data()),
rank_to_experts_start_index_d.get(),
nccl_->Size() * IndexTypeSize,
cudaMemcpyDeviceToHost,
Stream(context)));

CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(&cuda_event, cudaEventDisableTiming));
CUDA_RETURN_IF_ERROR(cudaEventRecord(cuda_event, Stream(context)));

return Status::OK();
}
#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
36 changes: 36 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
#include "contrib_ops/cuda/moe/moe_base.h"
#include "core/common/common.h"
#include "nccl_kernels.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

using namespace onnxruntime::cuda;

template <typename T>
class ShardedMoE final : public NcclKernel, public MoEBase {
public:
explicit ShardedMoE(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* ctx) const override;

private:
Status SynchronizeExpertsStartIndex(AllocatorPtr& alloc, OpKernelContext* ctx, cudaEvent_t& cuda_event) const;

int64_t local_experts_start_index_;
std::vector<int64_t> rank_to_experts_start_index_;
};

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -165,6 +165,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul);

@@ -364,6 +367,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul)>,
