From c012e41f9385303f486b644cd679fdb2784fe854 Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Wed, 6 Dec 2023 00:56:38 +0000
Subject: [PATCH] MoE with Expert Slicing (#18565)

### Description
Registered a ShardedMoE op under contrib_ops/cuda/collective that implements expert slicing. The broadcast happens just before the second bias (if present) is added and the permutation is undone. Tensor slicing is planned but not included in this PR.

### Motivation and Context

---
 cmake/onnxruntime_providers_cuda.cmake        |   2 +
 cmake/onnxruntime_rocm_hipify.cmake           |   2 +
 .../cuda/bert/transformer_cuda_common.h       |   2 +-
 .../cuda/collective/nccl_kernels.cc           |   4 +-
 .../cuda/collective/nccl_kernels.h            |   8 +-
 .../cuda/collective/sharded_moe.cc            | 204 ++++++++++++++
 .../contrib_ops/cuda/collective/sharded_moe.h |  36 +++
 .../contrib_ops/cuda/cuda_contrib_kernels.cc  |   6 +
 .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu |  96 ++++++-
 .../contrib_ops/cuda/moe/ft_moe/moe_kernel.h  |  27 +-
 onnxruntime/contrib_ops/cuda/moe/moe.cc       | 118 ++------
 onnxruntime/contrib_ops/cuda/moe/moe.h        |  25 +-
 onnxruntime/contrib_ops/cuda/moe/moe_base.h   | 172 ++++++++++++
 .../core/graph/contrib_ops/collective_defs.cc |  54 ++++
 .../transformers/sharded_moe/run_script.sh    |  10 +
 .../sharded_moe/test_sharded_moe.py           | 262 ++++++++++++++++++
 16 files changed, 884 insertions(+), 144 deletions(-)
 create mode 100644 onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
 create mode 100644 onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
 create mode 100644 onnxruntime/contrib_ops/cuda/moe/moe_base.h
 create mode 100644 onnxruntime/test/python/transformers/sharded_moe/run_script.sh
 create mode 100644 onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py

diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake
index cf298aee9fa85..84d1376f99d5e 100644
--- a/cmake/onnxruntime_providers_cuda.cmake
+++ b/cmake/onnxruntime_providers_cuda.cmake
@@ -34,6 +34,8 @@ if (NOT onnxruntime_USE_NCCL)
     list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs
       "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/nccl_kernels.cc"
+      "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharded_moe.h"
+      "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharded_moe.cc"
       "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding_spec.cc"
       "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc"
       "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc"
diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 980bd59b22c3f..f70961a66329a 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -109,6 +109,8 @@ if (NOT onnxruntime_USE_NCCL)
   # Those are string patterns to exclude. Do NOT use stars such as
   # collective/*.cc or *.h.
   list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc")
+  list(APPEND contrib_ops_excluded_files "collective/sharded_moe.h")
+  list(APPEND contrib_ops_excluded_files "collective/sharded_moe.cc")
   list(APPEND contrib_ops_excluded_files "collective/sharding.cc")
   list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc")
   list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc")
diff --git a/onnxruntime/contrib_ops/cuda/bert/transformer_cuda_common.h b/onnxruntime/contrib_ops/cuda/bert/transformer_cuda_common.h
index faf9310c4c3fd..a0da24210459c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/transformer_cuda_common.h
+++ b/onnxruntime/contrib_ops/cuda/bert/transformer_cuda_common.h
@@ -3,7 +3,7 @@
 
 #pragma once
 
-#include "core/providers/cuda/cuda_common.h"
+#include <cuda_runtime.h>
 
 namespace onnxruntime {
 namespace contrib {
diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
index 574a3133de815..0f42363bca22d 100644
--- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
@@ -24,9 +24,7 @@
 namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 
-#define NCCL_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(NCCL_CALL(expr))
-
-static ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) {
+ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) {
   if (type == DataTypeImpl::GetType<uint8_t>()) {
     return ncclUint8;
   } else if (type == DataTypeImpl::GetType<bool>()) {
diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
index 7fc26e6be57b9..9ea61f2bd952d 100644
--- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
+++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
@@ -7,17 +7,21 @@
 #if defined(ORT_USE_NCCL)
 #include
-#include
 #include
-#include
 #include
 #include
+#include
 #endif
 
 namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 
+#define NCCL_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(NCCL_CALL(expr))
+
+ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type);
+
 // -----------------------------------------------------------------------
 // Defines a new version of nccl classes
 // that independent with training::DistributedRunContext, only rely on MPI
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
new file mode 100644
index 0000000000000..40a667ffd5d83
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -0,0 +1,204 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/safeint.h"
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
+#include "sharded_moe.h"
+
+using namespace onnxruntime::cuda;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#if defined(ORT_USE_NCCL)
+
+#define REGISTER_KERNEL_TYPED(T)                                  \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
+      ShardedMoE,                                                 \
+      kMSDomain,                                                  \
+      1,                                                          \
+      T,                                                          \
+      kCudaExecutionProvider,                                     \
+      (*KernelDefBuilder::Create())                               \
+          .MayInplace(0, 0)                                       \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      ShardedMoE<T>);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+
+using namespace ONNX_NAMESPACE;
+
+template <typename T>
+ShardedMoE<T>::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) {
+  ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("local_experts_start_index", &local_experts_start_index_).IsOK());
+  rank_to_experts_start_index_.resize(nccl_->Size());
+  // Initialize rank_to_experts_start_index_[0] to a value to convey that it is not initialized.
+  rank_to_experts_start_index_[0] = std::numeric_limits<int64_t>::min();
+}
+
+template <typename T>
+Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
+  typedef typename ToCudaType<T>::MappedType CudaT;
+  auto stream = context->GetComputeStream();
+
+  auto& device_prop = GetDeviceProp();
+  const int sm = device_prop.major * 10 + device_prop.minor;
+
+  AllocatorPtr allocator;
+  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+
+  // Create a {Rank, ExpertsStartIndex} map on Host.
+  AutoDestoryCudaEvent cuda_event;
+  cudaEvent_t& copy_event = cuda_event.Get();
+  ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));
+
+  const Tensor* input = context->Input<Tensor>(0);
+  const Tensor* router_probs = context->Input<Tensor>(1);
+  const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
+  const Tensor* fc2_experts_weights = context->Input<Tensor>(3);
+  const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4);
+  const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(5);
+
+  MoEParameters moe_params;
+  ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights,
+                                  fc1_experts_bias_optional, fc2_experts_bias_optional));
+  ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
+                    "num_experts should be divisible by world_size");
+
+  ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm);
+
+  size_t ws_size =
+      moe_runner.getWorkspaceSize(static_cast<int>(moe_params.num_rows), static_cast<int>(moe_params.hidden_size),
+                                  static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
+                                  static_cast<int>(k_));
+
+  size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
+  size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
+  size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int);
+  size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int);
+
+  // TODO: allocate one buffer and reuse it.
+  IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, ws_size, false, stream);
+  IAllocatorUniquePtr<void> fc2_output = IAllocator::MakeUniquePtr<void>(allocator, fc2_output_size, false, stream);
+  IAllocatorUniquePtr<void> fc2_output_bc = IAllocator::MakeUniquePtr<void>(allocator, fc2_output_size, false, stream);
+  IAllocatorUniquePtr<void> expert_scales =
+      IAllocator::MakeUniquePtr<void>(allocator, expert_scales_size, false, stream);
+  IAllocatorUniquePtr<void> expanded_source_row_to_expanded_dest_row =
+      IAllocator::MakeUniquePtr<void>(allocator, expanded_source_row_to_expanded_dest_row_size, false, stream);
+  IAllocatorUniquePtr<void> expert_for_source_row =
+      IAllocator::MakeUniquePtr<void>(allocator, expert_for_source_row_size, false, stream);
+
+  // fc1_scales and fc2_scales are used in quantized MoE
+  const CudaT* fc1_scales_ptr = nullptr;
+  const CudaT* fc2_scales_ptr = nullptr;
+
+  moe_runner.run_moe_fc(reinterpret_cast<const CudaT*>(input->template Data<T>()),
+                        reinterpret_cast<const CudaT*>(router_probs->template Data<T>()),
+                        reinterpret_cast<const CudaT*>(fc1_experts_weights->template Data<T>()),
+                        std::move(fc1_scales_ptr),
+                        fc1_experts_bias_optional == nullptr
+                            ? nullptr
+                            : reinterpret_cast<const CudaT*>(fc1_experts_bias_optional->template Data<T>()),
+                        activation_type_, reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()),
+                        std::move(fc2_scales_ptr), static_cast<int>(moe_params.num_rows),
+                        static_cast<int>(moe_params.hidden_size),
+                        static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
+                        static_cast<int>(moe_params.local_num_experts), static_cast<int>(local_experts_start_index_),
+                        static_cast<int>(k_), reinterpret_cast<char*>(work_space.get()),
+                        reinterpret_cast<CudaT*>(fc2_output.get()), reinterpret_cast<CudaT*>(expert_scales.get()),
+                        reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
+                        reinterpret_cast<int*>(expert_for_source_row.get()), Stream(context));
+
+  Tensor* output = context->Output(0, input->Shape());
+
+  size_t stride_count = moe_params.hidden_size;
+  size_t stride_bytes = stride_count * sizeof(CudaT);
+  int64_t total_past_rows = 0;
+  int64_t total_covered_rows = 0;
+  if (copy_event != nullptr) {
+    CUDA_RETURN_IF_ERROR(cudaEventSynchronize(copy_event));
+  }
+  NCCL_RETURN_IF_ERROR(ncclGroupStart());
+  for (int rank = 0; rank < nccl_->Size(); ++rank) {
+    int64_t experts_start_index = rank_to_experts_start_index_[rank];
+    moe_runner.get_total_rows_info(experts_start_index,
+                                   moe_params.local_num_experts,
+                                   total_past_rows,
+                                   total_covered_rows);
+    const char* src = reinterpret_cast<const char*>(fc2_output.get()) + total_past_rows * stride_bytes;
+    char* dst = reinterpret_cast<char*>(fc2_output_bc.get()) + total_past_rows * stride_bytes;
+    NCCL_RETURN_IF_ERROR(ncclBroadcast(src,
+                                       dst,
+                                       total_covered_rows * stride_count,
+                                       GetNcclDataType(input->DataType()),
+                                       rank,
+                                       nccl_->Comm(),
+                                       Stream(context)));
+  }
+  NCCL_RETURN_IF_ERROR(ncclGroupEnd());
+
+  ort_fastertransformer::finalize_moe_routing_kernelLauncher(
+      reinterpret_cast<CudaT*>(fc2_output_bc.get()), reinterpret_cast<CudaT*>(output->template MutableData<T>()),
+      fc2_experts_bias_optional == nullptr
+          ? nullptr
+          : reinterpret_cast<const CudaT*>(fc2_experts_bias_optional->template Data<T>()),
+      reinterpret_cast<CudaT*>(expert_scales.get()),
+      reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
+      reinterpret_cast<int*>(expert_for_source_row.get()), static_cast<int>(moe_params.num_rows),
+      static_cast<int>(moe_params.hidden_size), static_cast<int>(k_), Stream(context));
+
+  return Status::OK();
+}
+
+template <typename T>
+Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
+                                                   OpKernelContext* context,
+                                                   cudaEvent_t& cuda_event) const {
+  if (rank_to_experts_start_index_[0] != std::numeric_limits<int64_t>::min()) {
+    return Status::OK();
+  }
+
+  auto stream = context->GetComputeStream();
+
+  using IndexType = int64_t;
+  size_t IndexTypeSize = sizeof(IndexType);
+
+  IAllocatorUniquePtr<IndexType> experts_start_index_d =
+      IAllocator::MakeUniquePtr<IndexType>(allocator, 1, false, stream);
+  IAllocatorUniquePtr<IndexType> rank_to_experts_start_index_d =
+      IAllocator::MakeUniquePtr<IndexType>(allocator, nccl_->Size(), false, stream);
+
+  // Only happens in the first run.
+  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(experts_start_index_d.get(),
+                                       &local_experts_start_index_,
+                                       IndexTypeSize,
+                                       cudaMemcpyHostToDevice,
+                                       Stream(context)));
+  NCCL_RETURN_IF_ERROR(ncclAllGather(reinterpret_cast<const char*>(experts_start_index_d.get()),
+                                     reinterpret_cast<char*>(rank_to_experts_start_index_d.get()),
+                                     1,
+                                     GetNcclDataType(DataTypeImpl::GetType<int64_t>()),
+                                     nccl_->Comm(),
+                                     Stream(context)));
+  // The const_cast<> violates the const modifier to make sure the synchronization happens only once per session.
+  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast<int64_t*>(rank_to_experts_start_index_.data()),
+                                       rank_to_experts_start_index_d.get(),
+                                       nccl_->Size() * IndexTypeSize,
+                                       cudaMemcpyDeviceToHost,
+                                       Stream(context)));
+
+  CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(&cuda_event, cudaEventDisableTiming));
+  CUDA_RETURN_IF_ERROR(cudaEventRecord(cuda_event, Stream(context)));
+
+  return Status::OK();
+}
+#endif
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
new file mode 100644
index 0000000000000..5ea4ae59c4020
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
+#include "contrib_ops/cuda/moe/moe_base.h"
+#include "core/common/common.h"
+#include "nccl_kernels.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#if defined(ORT_USE_NCCL)
+
+using namespace onnxruntime::cuda;
+
+template <typename T>
+class ShardedMoE final : public NcclKernel, public MoEBase {
+ public:
+  explicit ShardedMoE(const OpKernelInfo& op_kernel_info);
+  Status ComputeInternal(OpKernelContext* ctx) const override;
+
+ private:
+  Status SynchronizeExpertsStartIndex(AllocatorPtr& alloc, OpKernelContext* ctx, cudaEvent_t& cuda_event) const;
+
+  int64_t local_experts_start_index_;
+  std::vector<int64_t> rank_to_experts_start_index_;
+};
+
+#endif
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 108eea1a73fe9..7875ac75b8188 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -165,6 +165,9 @@
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);
 
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE);
+
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul);
 
@@ -364,6 +367,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,
+     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE)>,
+     BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE)>,
+
      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul)>,
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
index 398ce4ee9880f..f4f2b49032d23 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -13,6 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 
 #include
 #include
 
@@ -501,8 +503,27 @@ __global__ void compute_total_rows_before_expert_kernel(const int* sorted_expert
   total_rows_before_expert[expert] = find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
 }
 
+__global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, int num_experts,
+                                            int local_num_experts, int local_experts_start_index) {
+  const int expert = blockIdx.x * blockDim.x + threadIdx.x;
+  const int local_experts_end_index = local_experts_start_index + local_num_experts - 1;
+
+  int total_past_rows = 0;
+  if (local_experts_start_index > 0) {
+    total_past_rows = total_rows_before_expert[local_experts_start_index - 1];
+  }
+
+  if (expert < local_experts_start_index || expert > local_experts_end_index) {
+    return;
+  }
+
+  total_rows_before_expert[expert] -= total_past_rows;
+}
+
 template <typename T, typename WeightType>
 CutlassMoeFCRunner<T, WeightType>::CutlassMoeFCRunner(int sm_version) {
+  total_past_rows_ = 0;
+  total_covered_rows_ = 0;
   moe_gemm_runner_.initialize(sm_version);
 }
 
@@ -549,7 +570,6 @@ void CutlassMoeFCRunner<T, WeightType>::configure_ws_ptrs(char* ws_ptr,
   const int interbuf_size = static_cast<int>(pad_to_multiple_of_16(k * num_rows * inter_size));
   const int padded_experts = static_cast<int>(pad_to_multiple_of_16(num_experts));
   const int num_moe_inputs = static_cast<int>(pad_to_multiple_of_16(k * num_rows));
-  // const int num_softmax_outs = pad_to_multiple_of_16(num_rows * num_experts);
 
   source_rows_ = (int*)ws_ptr;
   permuted_rows_ = source_rows_ + num_moe_inputs;
@@ -573,8 +593,9 @@ void CutlassMoeFCRunner<T, WeightType>::run_moe_fc(
     const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales,
     const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights,
     const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts,
-    int k, char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, T* expert_scales,
-    int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) {
+    int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result,
+    const bool* finished, int active_rows, T* expert_scales, int* expanded_source_row_to_expanded_dest_row,
+    int* expert_for_source_row, cudaStream_t stream) {
   static constexpr bool scales_required =
       std::is_same<WeightType, uint8_t>::value || std::is_same<WeightType, cutlass::uint4b_t>::value;
 
@@ -608,12 +629,23 @@ void CutlassMoeFCRunner<T, WeightType>::run_moe_fc(
   compute_total_rows_before_expert(permuted_experts_, expanded_active_expert_rows, num_experts,
                                    total_rows_before_expert_, stream);
 
-  moe_gemm_runner_.moe_gemm_bias_act(permuted_data_, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_,
-                                     total_rows_before_expert_, expanded_active_expert_rows, inter_size, hidden_size,
-                                     num_experts, fc1_activation_type, stream);
+  if (local_num_experts < num_experts) {
+    dispatch_activations(total_rows_before_expert_, num_experts, local_num_experts, local_experts_start_index, stream);
+  }
 
-  moe_gemm_runner_.moe_gemm(fc1_result_, fc2_expert_weights, fc2_scales, fc2_result, total_rows_before_expert_,
-                            expanded_active_expert_rows, hidden_size, inter_size, num_experts, stream);
+  // expanded_active_expert_rows is not used
+  moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size,
+                                     fc1_expert_weights, fc1_scales, fc1_expert_biases,
+                                     fc1_result_ + total_past_rows_ * inter_size,
+                                     total_rows_before_expert_ + local_experts_start_index,
+                                     expanded_active_expert_rows, inter_size, hidden_size,
+                                     local_num_experts, fc1_activation_type, stream);
+
+  moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size,
+                            fc2_expert_weights, fc2_scales,
+                            fc2_result + total_past_rows_ * hidden_size,
+                            total_rows_before_expert_ + local_experts_start_index,
+                            expanded_active_expert_rows, hidden_size, inter_size, local_num_experts, stream);
 }
 
 template <typename T, typename WeightType>
@@ -621,12 +653,12 @@ void CutlassMoeFCRunner<T, WeightType>::run_moe_fc(
     const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights, const T* fc1_scales,
     const T* fc1_expert_biases, ActivationType fc1_activation_type, const WeightType* fc2_expert_weights,
     const T* fc2_scales, int num_rows, const int hidden_size, const int inter_size, int num_experts,
-    int k, char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row,
-    int* expert_for_source_row, cudaStream_t stream) {
+    int local_num_experts, int local_experts_start_index, int k, char* workspace_ptr, T* fc2_result, T* expert_scales,
+    int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream) {
   run_moe_fc(input_activations, gating_output, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_activation_type,
-             fc2_expert_weights, fc2_scales, num_rows, hidden_size, inter_size, num_experts, k, workspace_ptr,
-             fc2_result, nullptr, num_rows, expert_scales, expanded_source_row_to_expanded_dest_row,
-             expert_for_source_row, stream);
+             fc2_expert_weights, fc2_scales, num_rows, hidden_size, inter_size, num_experts, local_num_experts,
+             local_experts_start_index, k, workspace_ptr, fc2_result, nullptr, num_rows, expert_scales,
+             expanded_source_row_to_expanded_dest_row, expert_for_source_row, stream);
 }
 
 template <typename T, typename WeightType>
@@ -642,6 +674,44 @@ void CutlassMoeFCRunner<T, WeightType>::compute_total_rows_before_expert
                                         total_rows_before_expert);
 }
 
+template <typename T, typename WeightType>
+void CutlassMoeFCRunner<T, WeightType>::dispatch_activations(int64_t* total_rows_before_expert,
+                                                             int num_experts, int local_num_experts,
+                                                             int local_experts_start_index,
+                                                             cudaStream_t stream) {
+  total_rows_before_expert_host_.resize(num_experts);
+  cudaMemcpyAsync(total_rows_before_expert_host_.data(), total_rows_before_expert, num_experts * sizeof(int64_t),
+                  cudaMemcpyDeviceToHost, stream);
+
+  const int threads = std::min(1024, num_experts);
+  const int blocks = (num_experts + threads - 1) / threads;
+
+  cudaEvent_t& copy_event = cuda_event_.Get();
+  cudaEventCreateWithFlags(&copy_event, cudaEventDisableTiming);
+  cudaEventRecord(copy_event, stream);
+
+  dispatch_activations_kernel<<<blocks, threads, 0, stream>>>(total_rows_before_expert, num_experts,
+                                                              local_num_experts, local_experts_start_index);
+
+  get_total_rows_info(local_experts_start_index, local_num_experts, total_past_rows_, total_covered_rows_);
+}
+
+template <typename T, typename WeightType>
+void CutlassMoeFCRunner<T, WeightType>::get_total_rows_info(int64_t experts_start_index,
+                                                            int64_t local_num_experts,
+                                                            int64_t& total_past_rows,
+                                                            int64_t& total_covered_rows) {
+  int64_t experts_end_index = experts_start_index + local_num_experts - 1;
+  total_past_rows = 0;
+
+  cudaEventSynchronize(cuda_event_.Get());
+
+  if (experts_start_index > 0) {
+    total_past_rows = total_rows_before_expert_host_[experts_start_index - 1];
+  }
+  total_covered_rows = total_rows_before_expert_host_[experts_end_index] - total_past_rows;
+}
+
 // ========================== Permutation things =======================================
 
 // Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing.
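
Note on the row bookkeeping above: total_rows_before_expert holds an inclusive prefix sum of the rows routed to each expert, and each rank rebases the slice that covers its local experts so the grouped GEMMs can index from zero. The following is a minimal NumPy sketch of what dispatch_activations and get_total_rows_info compute (illustrative only, not part of the patch; the array values are made up):

    import numpy as np

    # Inclusive prefix sum of routed rows for 4 experts.
    total_rows_before_expert = np.array([3, 7, 12, 16], dtype=np.int64)

    def get_total_rows_info(start_index: int, local_num_experts: int):
        # Rows routed to experts owned by lower-ranked shards.
        total_past_rows = 0 if start_index == 0 else int(total_rows_before_expert[start_index - 1])
        end_index = start_index + local_num_experts - 1
        # Rows covered by this rank's local experts.
        total_covered_rows = int(total_rows_before_expert[end_index]) - total_past_rows
        return total_past_rows, total_covered_rows

    # Rank 1 of 2 owns experts 2..3: it skips 7 rows and covers 9 rows.
    past, covered = get_total_rows_info(2, 2)  # (7, 9)

    # dispatch_activations_kernel then rebases the local entries in place:
    total_rows_before_expert[2:4] -= past      # [3, 7, 5, 9]
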
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
index 5cefe4fa5dc47..5cc2a3f79f003 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -13,6 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 
 #pragma once
 
@@ -20,6 +22,7 @@
 #include
 
 #include "core/common/common.h"
+#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
 
 using namespace onnxruntime;
 
@@ -111,20 +114,26 @@ class CutlassMoeFCRunner {
   void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights,
                   const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type,
                   const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size,
-                  int inter_size, int num_experts, int k, char* workspace_ptr, T* fc2_result,
-                  T* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row,
-                  cudaStream_t stream);
+                  int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k,
+                  char* workspace_ptr, T* fc2_result, T* expert_scales, int* expanded_source_row_to_expanded_dest_row,
+                  int* expert_for_source_row, cudaStream_t stream);
 
   void run_moe_fc(const T* input_activations, const T* gating_output, const WeightType* fc1_expert_weights,
                   const T* fc1_scales, const T* fc1_expert_biases, ActivationType fc1_activation_type,
                   const WeightType* fc2_expert_weights, const T* fc2_scales, int num_rows, int hidden_size,
-                  int inter_size, int num_experts, int k, char* workspace_ptr, T* fc2_result,
-                  const bool* finished, int active_rows, T* expert_scales,
+                  int inter_size, int num_experts, int local_num_experts, int local_experts_start_index, int k,
+                  char* workspace_ptr, T* fc2_result, const bool* finished, int active_rows, T* expert_scales,
                   int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, cudaStream_t stream);
 
   void compute_total_rows_before_expert(const int* sorted_indices, int total_indices, int num_experts,
                                         int64_t* total_rows_before_expert, cudaStream_t stream);
 
+  void dispatch_activations(int64_t* total_rows_before_expert, int num_experts, int local_num_experts,
+                            int local_experts_start_index, cudaStream_t stream);
+
+  void get_total_rows_info(int64_t experts_start_index, int64_t local_num_experts, int64_t& total_past_rows,
+                           int64_t& total_covered_rows);
+
  private:
   void configure_ws_ptrs(char* ws_ptr, int num_rows, int hidden_size, int inter_size, int num_experts, int k);
 
@@ -143,6 +152,14 @@ class CutlassMoeFCRunner {
   int64_t* total_rows_before_expert_;
 
   T* fc1_result_;
+
+  // Cuda events
+  contrib::cuda::AutoDestoryCudaEvent cuda_event_;
+
+  int64_t total_past_rows_;
+  int64_t total_covered_rows_;
+  // TODO: use pinned memory
+  std::vector<int64_t> total_rows_before_expert_host_;
 };
 
 template
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc
index 6f2ffe7a0cc43..3f26a274109ad 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -30,6 +30,10 @@
 REGISTER_KERNEL_TYPED(MLFloat16)
 
 using namespace ONNX_NAMESPACE;
 
+template <typename T>
+MoE<T>::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {
+}
+
 template <typename T>
 Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
   const Tensor* input = context->Input<Tensor>(0);
@@ -39,95 +43,9 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
   const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4);
   const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(5);
 
-  const auto& input_dims = input->Shape().GetDims();
-  const auto& router_probs_dims = router_probs->Shape().GetDims();
-  const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
-  const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims();
-
-  const int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1];
-  const int64_t hidden_size = input_dims[input_dims.size() - 1];
-  const int64_t num_experts = fc1_experts_weights_dims[0];
-  const int64_t inter_size = fc1_experts_weights_dims[2];
-
-  // TODO: refactor to helper function.
-  if (fc1_experts_weights_dims.size() != 3) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ",
-                           fc1_experts_weights_dims.size());
-  }
-  if (fc2_experts_weights_dims.size() != 3) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ",
-                           fc2_experts_weights_dims.size());
-  }
-  if (fc1_experts_weights_dims[1] != hidden_size) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "fc1_experts_weights_dims[1] must be equal to hidden_size, got ",
-                           fc1_experts_weights_dims[1], " and ", hidden_size);
-  }
-  if (fc2_experts_weights_dims[1] != inter_size) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "fc2_experts_weights_dims[1] must be equal to inter_size, got ", fc2_experts_weights_dims[1],
-                           " and ", inter_size);
-  }
-  if (fc1_experts_weights_dims[2] != inter_size) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "fc1_experts_weights_dims[2] must be equal to inter_size, got ", fc1_experts_weights_dims[2],
-                           " and ", inter_size);
-  }
-  if (fc2_experts_weights_dims[2] != hidden_size) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "fc2_experts_weights_dims[2] must be equal to hidden_size, got ",
-                           fc2_experts_weights_dims[2], " and ", hidden_size);
-  }
-  if (router_probs_dims.size() != 2) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ",
-                           router_probs_dims.size());
-  }
-  if (router_probs_dims[0] != num_rows) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ",
-                           router_probs_dims[0], " and ", num_rows);
-  }
-  if (router_probs_dims[1] != num_experts) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[1] must be equal to num_experts, got ",
-                           router_probs_dims[1], " and ", num_experts);
-  }
-  if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional == nullptr) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is set but fc2_experts_bias is not set");
-  }
-  if (fc1_experts_bias_optional == nullptr && fc2_experts_bias_optional != nullptr) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is not set but fc2_experts_bias is set");
-  }
-  if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) {
-    const auto& fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims();
-    const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims();
-    if (fc1_experts_bias_dims.size() != 2) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ",
-                             fc1_experts_bias_dims.size());
-    }
-    if (fc2_experts_bias_dims.size() != 2) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ",
-                             fc2_experts_bias_dims.size());
-    }
-    if (fc1_experts_bias_dims[0] != num_experts) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "fc1_experts_bias_dims[0] must be equal to num_experts, got ", fc1_experts_bias_dims[0],
-                             " and ", num_experts);
-    }
-    if (fc2_experts_bias_dims[0] != num_experts) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0],
-                             " and ", num_experts);
-    }
-    if (fc1_experts_bias_dims[1] != inter_size) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1],
-                             " and ", inter_size);
-    }
-    if (fc2_experts_bias_dims[1] != hidden_size) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "fc2_experts_bias_dims[1] must be equal to hidden_size, got ", fc2_experts_bias_dims[1],
-                             " and ", hidden_size);
-    }
-  }
+  MoEParameters moe_params;
+  ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights,
+                                  fc1_experts_bias_optional, fc2_experts_bias_optional));
 
   typedef typename ToCudaType<T>::MappedType CudaT;
   auto stream = context->GetComputeStream();
@@ -138,12 +56,13 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
   ort_fastertransformer::CutlassMoeFCRunner<CudaT, CudaT> moe_runner(sm);
 
   size_t ws_size =
-      moe_runner.getWorkspaceSize(static_cast<int>(num_rows), static_cast<int>(hidden_size),
-                                  static_cast<int>(inter_size), static_cast<int>(num_experts), static_cast<int>(k_));
-  size_t fc2_output_size = k_ * num_rows * hidden_size * sizeof(CudaT);
-  size_t expert_scales_size = k_ * num_rows * sizeof(CudaT);
-  size_t expanded_source_row_to_expanded_dest_row_size = k_ * num_rows * sizeof(int);
-  size_t expert_for_source_row_size = k_ * num_rows * sizeof(int);
+      moe_runner.getWorkspaceSize(static_cast<int>(moe_params.num_rows), static_cast<int>(moe_params.hidden_size),
+                                  static_cast<int>(moe_params.inter_size), static_cast<int>(moe_params.num_experts),
+                                  static_cast<int>(k_));
+  size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
+  size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
+  size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int);
+  size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int);
 
   AllocatorPtr allocator;
   ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
@@ -170,8 +89,10 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
                             ? nullptr
                             : reinterpret_cast<const CudaT*>(fc1_experts_bias_optional->template Data<T>()),
                         activation_type_, reinterpret_cast<const CudaT*>(fc2_experts_weights->template Data<T>()),
-                        std::move(fc2_scales_ptr), static_cast<int>(num_rows), static_cast<int>(hidden_size),
-                        static_cast<int>(inter_size), static_cast<int>(num_experts), static_cast<int>(k_),
+                        std::move(fc2_scales_ptr), static_cast<int>(moe_params.num_rows),
+                        static_cast<int>(moe_params.hidden_size), static_cast<int>(moe_params.inter_size),
+                        static_cast<int>(moe_params.num_experts), static_cast<int>(moe_params.local_num_experts),
+                        0 /*local_experts_start_index_ used in sharded MoE*/, static_cast<int>(k_),
                         reinterpret_cast<char*>(work_space.get()), reinterpret_cast<CudaT*>(fc2_output.get()),
                         reinterpret_cast<CudaT*>(expert_scales.get()),
                         reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
@@ -186,7 +107,8 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
           : reinterpret_cast<const CudaT*>(fc2_experts_bias_optional->template Data<T>()),
       reinterpret_cast<CudaT*>(expert_scales.get()),
       reinterpret_cast<int*>(expanded_source_row_to_expanded_dest_row.get()),
-      reinterpret_cast<int*>(expert_for_source_row.get()), static_cast<int>(num_rows), static_cast<int>(hidden_size),
+      reinterpret_cast<int*>(expert_for_source_row.get()), static_cast<int>(moe_params.num_rows),
+      static_cast<int>(moe_params.hidden_size),
       static_cast<int>(k_), Stream(context));
 
   return Status::OK();
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h
index 8035568693814..c4d8c4dc64c57 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.h
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.h
@@ -4,6 +4,7 @@
 #pragma once
 
 #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
+#include "contrib_ops/cuda/moe/moe_base.h"
 #include "core/common/common.h"
 #include "core/providers/cuda/cuda_kernel.h"
 
@@ -14,30 +15,10 @@
 namespace cuda {
 
 using namespace onnxruntime::cuda;
 
 template <typename T>
-class MoE final : public CudaKernel {
+class MoE final : public CudaKernel, public MoEBase {
  public:
-  explicit MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
-    ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("k", &k_).IsOK());
-
-    std::string activation_type_str;
-    ORT_ENFORCE(op_kernel_info.GetAttr<std::string>("activation_type", &activation_type_str).IsOK());
-    if (activation_type_str == "relu") {
-      activation_type_ = ort_fastertransformer::ActivationType::Relu;
-    } else if (activation_type_str == "gelu") {
-      activation_type_ = ort_fastertransformer::ActivationType::Gelu;
-    } else if (activation_type_str == "silu") {
-      activation_type_ = ort_fastertransformer::ActivationType::Silu;
-    } else if (activation_type_str == "identity") {
-      activation_type_ = ort_fastertransformer::ActivationType::Identity;
-    } else {
-      ORT_THROW("Unsupported MoE activation type: ", activation_type_str);
-    }
-  }
+  explicit MoE(const OpKernelInfo& op_kernel_info);
 
   Status ComputeInternal(OpKernelContext* ctx) const override;
-
- private:
-  int64_t k_;
-  ort_fastertransformer::ActivationType activation_type_;
 };
 
 }  // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h
new file mode 100644
index 0000000000000..f55a7cde2e208
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h
@@ -0,0 +1,172 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+enum class MoEParallelType {
+  None = 0,
+  ExpertSlicing = 1,
+};
+
+struct MoEParameters {
+  int64_t num_rows;
+  int64_t num_experts;
+  int64_t local_num_experts;
+  int64_t hidden_size;
+  int64_t inter_size;
+  MoEParallelType parallel_type;
+};
+
+class MoEBase {
+ public:
+  Status CheckInputs(MoEParameters& parameters,
+                     const Tensor* input,
+                     const Tensor* router_probs,
+                     const Tensor* fc1_experts_weights,
+                     const Tensor* fc2_experts_weights,
+                     const Tensor* fc1_experts_bias_optional,
+                     const Tensor* fc2_experts_bias_optional) const {
+    const auto& input_dims = input->Shape().GetDims();
+    const auto& router_probs_dims = router_probs->Shape().GetDims();
+    const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
+    const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims();
+
+    int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1];
+    int64_t hidden_size = input_dims[input_dims.size() - 1];
+    int64_t local_num_experts = fc1_experts_weights_dims[0];
+    int64_t num_experts = router_probs_dims[1];
+    int64_t inter_size = fc1_experts_weights_dims[2];
+
+    if (fc1_experts_weights_dims.size() != 3) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_weights_dims must be 3D, got ",
+                             fc1_experts_weights_dims.size());
+    }
+    if (fc2_experts_weights_dims.size() != 3) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_weights_dims must be 3D, got ",
+                             fc2_experts_weights_dims.size());
+    }
+    if (fc1_experts_weights_dims[1] != hidden_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "fc1_experts_weights_dims[1] must be equal to hidden_size, got ",
+                             fc1_experts_weights_dims[1], " and ", hidden_size);
+    }
+    if (fc2_experts_weights_dims[1] != inter_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "fc2_experts_weights_dims[1] must be equal to inter_size, got ",
+                             fc2_experts_weights_dims[1],
+                             " and ", inter_size);
+    }
+    if (fc1_experts_weights_dims[2] != inter_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "fc1_experts_weights_dims[2] must be equal to inter_size, got ",
+                             fc1_experts_weights_dims[2],
+                             " and ", inter_size);
+    }
+    if (fc2_experts_weights_dims[2] != hidden_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "fc2_experts_weights_dims[2] must be equal to hidden_size, got ",
+                             fc2_experts_weights_dims[2], " and ", hidden_size);
+    }
+    if (router_probs_dims.size() != 2) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims must be 2D, got ",
+                             router_probs_dims.size());
+    }
+    if (router_probs_dims[0] != num_rows) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "router_probs_dims[0] must be equal to num_rows, got ",
+                             router_probs_dims[0], " and ", num_rows);
+    }
+    if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional == nullptr) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is set but fc2_experts_bias is not set");
+    }
+    if (fc1_experts_bias_optional == nullptr && fc2_experts_bias_optional != nullptr) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias is not set but fc2_experts_bias is set");
+    }
+    if (fc1_experts_bias_optional != nullptr && fc2_experts_bias_optional != nullptr) {
+      const auto&
 fc1_experts_bias_dims = fc1_experts_bias_optional->Shape().GetDims();
+      const auto& fc2_experts_bias_dims = fc2_experts_bias_optional->Shape().GetDims();
+      if (fc1_experts_bias_dims.size() != 2) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_bias_dims must be 2D, got ",
+                               fc1_experts_bias_dims.size());
+      }
+      if (fc2_experts_bias_dims.size() != 2) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_bias_dims must be 2D, got ",
+                               fc2_experts_bias_dims.size());
+      }
+      if (fc1_experts_bias_dims[0] != local_num_experts) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "fc1_experts_bias_dims[0] must be equal to local_num_experts, got ",
+                               fc1_experts_bias_dims[0],
+                               " and ", local_num_experts);
+      }
+      if (fc2_experts_bias_dims[0] != num_experts) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "fc2_experts_bias_dims[0] must be equal to num_experts, got ",
+                               fc2_experts_bias_dims[0],
+                               " and ", num_experts);
+      }
+      if (fc1_experts_bias_dims[1] != inter_size) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "fc1_experts_bias_dims[1] must be equal to inter_size, got ",
+                               fc1_experts_bias_dims[1],
+                               " and ", inter_size);
+      }
+      if (fc2_experts_bias_dims[1] != hidden_size) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "fc2_experts_bias_dims[1] must be equal to hidden_size, got ",
+                               fc2_experts_bias_dims[1],
+                               " and ", hidden_size);
+      }
+    }
+
+    parameters.num_rows = num_rows;
+    parameters.num_experts = num_experts;
+    parameters.local_num_experts = local_num_experts;
+    parameters.hidden_size = hidden_size;
+    parameters.inter_size = inter_size;
+    if (num_experts == local_num_experts) {
+      parameters.parallel_type = MoEParallelType::None;
+    } else if (num_experts > local_num_experts) {
+      parameters.parallel_type = MoEParallelType::ExpertSlicing;
+    } else {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "num_experts must be greater than or equal to local_num_experts, got ",
+                             num_experts, " and ", local_num_experts);
+    }
+
+    return Status::OK();
+  }
+
+ protected:
+  MoEBase(const OpKernelInfo& op_kernel_info) {
+    ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("k", &k_).IsOK());
+
+    std::string activation_type_str;
+    ORT_ENFORCE(op_kernel_info.GetAttr<std::string>("activation_type", &activation_type_str).IsOK());
+    if (activation_type_str == "relu") {
+      activation_type_ = ort_fastertransformer::ActivationType::Relu;
+    } else if (activation_type_str == "gelu") {
+      activation_type_ = ort_fastertransformer::ActivationType::Gelu;
+    } else if (activation_type_str == "silu") {
+      activation_type_ = ort_fastertransformer::ActivationType::Silu;
+    } else if (activation_type_str == "identity") {
+      activation_type_ = ort_fastertransformer::ActivationType::Identity;
+    } else {
+      ORT_THROW("Unsupported MoE activation type: ", activation_type_str);
+    }
+  }
+
+  int64_t k_;
+  ort_fastertransformer::ActivationType activation_type_;
+};
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc
index 59adfc523c860..4aa43f5de1cd5 100644
--- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc
@@ -80,6 +80,60 @@ void RegisterCollectiveOps() {
         propagateShapeAndTypeFromFirstInput(ctx);
       });
 
+  ONNX_CONTRIB_OPERATOR_SCHEMA(ShardedMoE)
+      .SetDomain(kMSDomain)
+      .SinceVersion(1)
+      .Attr("activation_type",
+            "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu",
+            AttributeProto::STRING,
+            std::string("relu"))
+      .Attr("k",
+            "Number of top experts to select from expert pool",
+            AttributeProto::INT,
+            static_cast<int64_t>(1))
+      .Attr("local_experts_start_index",
+            "The start index of local experts",
+            AttributeProto::INT,
+            static_cast<int64_t>(-1))
+      .Input(0,
+             "input",
+             "2D input tensor with shape (num_rows, hidden_size) or "
+             "3D input tensor with shape (batch_size, sequence_length, hidden_size)",
+             "T")
+      .Input(1,
+             "router_probs",
+             "2D input tensor with shape (num_rows, num_experts)",
+             "T")
+      .Input(2,
+             "fc1_experts_weights",
+             "3D input tensor with shape (local_num_experts, hidden_size, inter_size)",
+             "T")
+      .Input(3,
+             "fc2_experts_weights",
+             "3D input tensor with shape (local_num_experts, inter_size, hidden_size)",
+             "T")
+      .Input(4,
+             "fc1_experts_bias",
+             "2D optional input tensor with shape (local_num_experts, inter_size)",
+             "T",
+             OpSchema::Optional)
+      .Input(5,
+             "fc2_experts_bias",
+             "2D optional input tensor with shape (num_experts, hidden_size)",
+             "T",
+             OpSchema::Optional)
+      .Output(0,
+              "output",
+              "2D output tensor with shape (num_rows, hidden_size) or "
+              "3D output tensor with shape (batch_size, sequence_length, hidden_size)",
+              "T")
+      .TypeConstraint("T",
+                      {"tensor(float)", "tensor(float16)"},
+                      "Constrain input and output types to float or float16 tensors.")
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        propagateShapeAndTypeFromFirstInput(ctx);
+      });
+
   ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedMatMul)
       .SetDomain(kMSDomain)
       .SinceVersion(1)
diff --git a/onnxruntime/test/python/transformers/sharded_moe/run_script.sh b/onnxruntime/test/python/transformers/sharded_moe/run_script.sh
new file mode 100644
index 0000000000000..c591d777c4287
--- /dev/null
+++ b/onnxruntime/test/python/transformers/sharded_moe/run_script.sh
@@ -0,0 +1,10 @@
+
+MPI="mpirun --allow-run-as-root
+    -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0
+    --tag-output --npernode 4 --bind-to numa
+    -x MIOPEN_FIND_MODE=1"
+
+CMD="$MPI python test_sharded_moe.py"
+
+set -x
+$CMD
diff --git a/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py b/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py
new file mode 100644
index 0000000000000..af835d2906e87
--- /dev/null
+++ b/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py
@@ -0,0 +1,262 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------- + +import unittest + +import numpy as np +from mpi4py import MPI +from onnx import TensorProto, helper + +import onnxruntime + +np.random.seed(3) + +comm = MPI.COMM_WORLD + + +def get_rank(): + return comm.Get_rank() + + +def get_size(): + return comm.Get_size() + + +def barrier(): + comm.Barrier() + + +def print_out(*args): + if get_rank() == 0: + print(*args) + + +def broadcast(data): + comm = MPI.COMM_WORLD + comm.broadcast(data, root=0) + + +local_rank = get_rank() + +ORT_DTYPE = TensorProto.FLOAT16 +NP_TYPE = np.float16 if ORT_DTYPE == TensorProto.FLOAT16 else np.float32 +THRESHOLD = 1e-3 + + +def create_moe_onnx_graph( + num_rows, + num_experts, + local_num_experts, + hidden_size, + inter_size, + fc1_experts_weights, + fc2_experts_weights, + fc1_experts_bias, + fc2_experts_bias, + local_experts_start_index=-1, +): + use_sharded_moe = local_experts_start_index >= 0 + nodes = [ + helper.make_node( + "MoE", + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc2_experts_weights", + "fc1_experts_bias", + "fc2_experts_bias", + ], + ["output"], + "MoE_0", + k=1, + activation_type="gelu", + domain="com.microsoft", + ) + if not use_sharded_moe + else helper.make_node( + "ShardedMoE", + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc2_experts_weights", + "fc1_experts_bias", + "fc2_experts_bias", + ], + ["output"], + "MoE_0", + k=1, + activation_type="gelu", + local_experts_start_index=local_experts_start_index, + domain="com.microsoft", + ), + ] + + fc1_shape = [local_num_experts, hidden_size, inter_size] + fc2_shape = [local_num_experts, inter_size, hidden_size] + + initializers = [ + helper.make_tensor( + "fc1_experts_weights", + ORT_DTYPE, + fc1_shape, + fc1_experts_weights.flatten(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_weights", + ORT_DTYPE, + fc2_shape, + fc2_experts_weights.flatten(), + raw=False, + ), + ] + + fc1_bias_shape = [local_num_experts, inter_size] + fc2_bias_shape = [num_experts, hidden_size] + initializers.extend( + [ + helper.make_tensor( + "fc1_experts_bias", + ORT_DTYPE, + fc1_bias_shape, + fc1_experts_bias.flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_bias", + ORT_DTYPE, + fc2_bias_shape, + fc2_experts_bias.flatten().tolist(), + raw=False, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + ORT_DTYPE, + [num_rows, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def test_moe_with_expert_slicing( + hidden_size, + inter_size, + num_experts, + num_rows, +): + local_experts_start_index = local_rank * num_experts // get_size() + + fc1_experts_weights_all = np.random.rand(num_experts, hidden_size, inter_size).astype(NP_TYPE) + fc2_experts_weights_all = np.random.rand(num_experts, inter_size, hidden_size).astype(NP_TYPE) + fc1_experts_bias_all = np.random.rand(num_experts, inter_size).astype(NP_TYPE) + fc2_experts_bias_all = np.random.rand(num_experts, hidden_size).astype(NP_TYPE) + + onnx_model_full = create_moe_onnx_graph( + num_rows, + num_experts, + num_experts, + hidden_size, + inter_size, + fc1_experts_weights_all, + 
+        fc2_experts_weights_all,
+        fc1_experts_bias_all,
+        fc2_experts_bias_all,
+    )
+
+    fc1_experts_weights = fc1_experts_weights_all[
+        local_experts_start_index : local_experts_start_index + num_experts // get_size(), :, :
+    ]
+    fc2_experts_weights = fc2_experts_weights_all[
+        local_experts_start_index : local_experts_start_index + num_experts // get_size(), :, :
+    ]
+    fc1_experts_bias = fc1_experts_bias_all[
+        local_experts_start_index : local_experts_start_index + num_experts // get_size(), :
+    ]
+
+    onnx_model_local = create_moe_onnx_graph(
+        num_rows,
+        num_experts,
+        num_experts // get_size(),
+        hidden_size,
+        inter_size,
+        fc1_experts_weights,
+        fc2_experts_weights,
+        fc1_experts_bias,
+        fc2_experts_bias_all,
+        local_experts_start_index,
+    )
+
+    sess_options = onnxruntime.SessionOptions()
+    cuda_provider_options = {"device_id": local_rank}
+    execution_providers = [("CUDAExecutionProvider", cuda_provider_options)]
+
+    ort_session = onnxruntime.InferenceSession(onnx_model_full, sess_options, providers=execution_providers)
+    ort_session_local = onnxruntime.InferenceSession(onnx_model_local, sess_options, providers=execution_providers)
+
+    ort_inputs = {
+        ort_session.get_inputs()[0].name: np.random.rand(num_rows, hidden_size).astype(NP_TYPE),
+        ort_session.get_inputs()[1].name: np.random.rand(num_rows, num_experts).astype(NP_TYPE),
+    }
+
+    output = ort_session.run(None, ort_inputs)
+    sharded_output = ort_session_local.run(None, ort_inputs)
+
+    assert np.allclose(output[0], sharded_output[0], atol=THRESHOLD, rtol=THRESHOLD)
+
+    print_out(
+        "hidden_size: ",
+        hidden_size,
+        " inter_size: ",
+        inter_size,
+        " num_experts: ",
+        num_experts,
+        " num_rows: ",
+        num_rows,
+        " world_size: ",
+        get_size(),
+        " Parity: OK",
+    )
+
+
+class TestMoE(unittest.TestCase):
+    def test_moe_expert_slicing(self):
+        for hidden_size in [16, 128]:
+            for inter_size in [512, 1024]:
+                for num_experts in [8, 16, 32]:
+                    for num_rows in [16, 128, 512]:
+                        test_moe_with_expert_slicing(
+                            hidden_size,
+                            inter_size,
+                            num_experts,
+                            num_rows,
+                        )
+
+
+if __name__ == "__main__":
+    unittest.main()
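
A usage note on the sharding contract (a sketch under the same assumptions as the test above, not additional test code): every rank must own a contiguous, equal-sized block of experts, which is why ComputeInternal enforces num_experts % world_size == 0 and the test derives local_experts_start_index from the MPI rank.

    def partition_experts(num_experts: int, world_size: int, rank: int):
        # Mirrors the kernel-side check: experts are sliced evenly across ranks.
        assert num_experts % world_size == 0, "num_experts should be divisible by world_size"
        local_num_experts = num_experts // world_size
        local_experts_start_index = rank * local_num_experts
        return local_experts_start_index, local_num_experts

    # With 16 experts on 4 ranks, rank 2 owns experts 8..11:
    start, count = partition_experts(16, 4, 2)  # (8, 4)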