From 410c28dde6d51f520a8ac090c72b9e68e08bb554 Mon Sep 17 00:00:00 2001
From: Your
Date: Mon, 20 Nov 2023 23:28:57 +0000
Subject: [PATCH] add test files

---
 .../cuda/collective/sharded_moe.cc            |   9 +
 .../contrib_ops/cuda/cuda_contrib_kernels.cc  |   3 +
 .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu |  26 ++-
 .../contrib_ops/cuda/moe/ft_moe/moe_kernel.h  |   2 +
 .../transformers/sharded_moe/run_script.sh    |  10 +
 .../sharded_moe/test_sharded_moe.py           | 205 ++++++++++++++++++
 6 files changed, 244 insertions(+), 11 deletions(-)
 create mode 100644 onnxruntime/test/python/transformers/sharded_moe/run_script.sh
 create mode 100644 onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py

diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
index 1bb433a13e942..1f5531a010209 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -50,6 +50,15 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
   ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights,
                                   fc1_experts_bias_optional, fc2_experts_bias_optional));
 
+  // print out parameters for each rank
+  std::cout << "nccl rank:" << nccl_->Rank() << " " << "local_experts_start_index: " << local_experts_start_index_ << std::endl;
+  std::cout << "nccl rank:" << nccl_->Rank() << " " << "num_rows: " << moe_params.num_rows << std::endl;
+  std::cout << "nccl rank:" << nccl_->Rank() << " " << "hidden_size: " << moe_params.hidden_size << std::endl;
+  std::cout << "nccl rank:" << nccl_->Rank() << " " << "inter_size: " << moe_params.inter_size << std::endl;
+  std::cout << "nccl rank:" << nccl_->Rank() << " " << "num_experts: " << moe_params.num_experts << std::endl;
+  std::cout << "nccl rank:" << nccl_->Rank() << " " << "local_num_experts: " << moe_params.local_num_experts << std::endl;
+  std::cout << "nccl rank:" << nccl_->Rank() << " " << "k: " << k_ << std::endl;
+
   typedef typename ToCudaType::MappedType CudaT;
   auto stream = context->GetComputeStream();
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 45b1716e9dfc0..01ccdbcbd3765 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -365,6 +365,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
+      BuildKernelCreateInfo,
       BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
index e37436ec9d214..79c9f7b0f8729 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -13,6 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 
 #include
 #include
@@ -505,8 +507,6 @@ __global__ void compute_total_rows_before_expert_kernel(const int* sorted_expert
 __global__ void dispatch_activations_kernel(int64_t*& total_rows_before_expert, int num_experts,
                                             int local_num_experts, int local_experts_start_index,
                                             int& total_past_rows, int& total_covered_rows) {
-  // permuted_experts_ : 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6,
-  // total_rows_before_expert_ : 1, 5, 12, 19, 25, 31, 32, 32,
   const int expert = blockIdx.x * blockDim.x + threadIdx.x;
   const int local_experts_end_index = local_experts_start_index + local_num_experts - 1;
@@ -522,6 +522,8 @@ __global__ void dispatch_activations_kernel(int64_t*& total_rows_before_expert,
 template
 CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version) {
+  total_past_rows_ = 0;
+  total_covered_rows_ = 0;
   moe_gemm_runner_.initialize(sm_version);
 }
@@ -615,32 +617,34 @@ void CutlassMoeFCRunner::run_moe_fc(
   configure_ws_ptrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts, k);
   topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row,
                                      source_rows_, num_rows, num_experts, k, stream);
-  // source_rows_: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
-  // expert_for_source_row : 3, 4, 1, 1, 2, 1, 2, 4, 5, 4, 3, 3, 5, 2, 5, 4, 4, 3, 3, 3, 2, 5, 5, 0, 2, 3, 4, 6, 2, 1, 5, 2,
+  print_cuda_buffer("source_rows_", source_rows_, num_rows);
+  print_cuda_buffer("expert_for_source_row", expert_for_source_row, num_rows);
 
   const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)));
   sorter_.run((void*)fc1_result_, sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_,
               permuted_rows_, k * num_rows, stream);
-  // permuted_experts_ : 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6,
-  // permuted_rows_ : 23, 2, 3, 5, 29, 4, 6, 13, 20, 24, 28, 31, 0, 10, 11, 17, 18, 19, 25, 1, 7, 9, 15, 16, 26, 8, 12, 14, 21, 22, 30, 27,
-  // source_rows_ : 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+  print_cuda_buffer("permuted_experts_", permuted_experts_, num_rows);
+  print_cuda_buffer("permuted_rows_", permuted_rows_, num_rows);
+  print_cuda_buffer("source_rows_", source_rows_, num_rows);
 
   initialize_moe_routing_kernelLauncher(input_activations, permuted_data_, permuted_rows_,
                                         expanded_source_row_to_expanded_dest_row, num_rows, active_rows, hidden_size,
                                         k, stream);
-  // expanded_source_row_to_expanded_dest_row : 12, 19, 1, 2, 5, 3, 6, 20, 25, 21, 13, 14, 26, 7, 27, 22, 23, 15, 16, 17, 8, 28, 29, 0, 9, 18, 24, 31, 10, 4, 30, 11,
-  // permuted_rows_ : 23, 2, 3, 5, 29, 4, 6, 13, 20, 24, 28, 31, 0, 10, 11, 17, 18, 19, 25, 1, 7, 9, 15, 16, 26, 8, 12, 14, 21, 22, 30, 27,
+  print_cuda_buffer("expanded_source_row_to_expanded_dest_row", expanded_source_row_to_expanded_dest_row, num_rows);
+  print_cuda_buffer("permuted_rows_", permuted_rows_, num_rows);
 
   const int expanded_active_expert_rows = k * active_rows;
   compute_total_rows_before_expert(permuted_experts_, expanded_active_expert_rows, num_experts,
                                    total_rows_before_expert_, stream);
-  // total_rows_before_expert_ : 1, 5, 12, 19, 25, 31, 32, 32,
+  print_cuda_buffer("total_rows_before_expert_", total_rows_before_expert_, num_experts);
 
   if (local_num_experts < num_experts) {
     dispatch_activations(total_rows_before_expert_, num_experts, local_num_experts, local_experts_start_index,
                          total_past_rows_, total_covered_rows_, stream);
-    // bugbug: use cuda event
+    // TODO: use cuda event
     cudaDeviceSynchronize();
+    std::cout << "total_past_rows_ = " << total_past_rows_ << ", total_covered_rows_ = " << total_covered_rows_
+              << std::endl;
   }
 
   // expanded_active_expert_rows is not used
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
index a2b9f6e229522..2c21e5f827120 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -13,6 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
 
 #pragma once
diff --git a/onnxruntime/test/python/transformers/sharded_moe/run_script.sh b/onnxruntime/test/python/transformers/sharded_moe/run_script.sh
new file mode 100644
index 0000000000000..8eba627f79cdb
--- /dev/null
+++ b/onnxruntime/test/python/transformers/sharded_moe/run_script.sh
@@ -0,0 +1,10 @@
+
+MPI="mpirun --allow-run-as-root
+    -mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0
+    --tag-output --npernode 2 --bind-to numa
+    -x MIOPEN_FIND_MODE=1"
+
+CMD="$MPI python test_sharded_moe.py"
+
+set -x
+$CMD
diff --git a/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py b/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py
new file mode 100644
index 0000000000000..a9977ee5463dc
--- /dev/null
+++ b/onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py
@@ -0,0 +1,205 @@
+import os
+from mpi4py import MPI
+import onnxruntime
+import numpy as np
+from onnx import TensorProto, helper
+
+np.random.seed(42)
+
+comm = MPI.COMM_WORLD
+
+
+def get_rank():
+    return comm.Get_rank()
+
+
+def get_size():
+    return comm.Get_size()
+
+
+def barrier():
+    comm.Barrier()
+
+
+def print_out(*args):
+    if get_rank() == 0:
+        print(*args)
+
+
+def broadcast(data):
+    # mpi4py's Comm exposes bcast() for Python objects; there is no broadcast() method.
+    return comm.bcast(data, root=0)
+
+
+local_rank = get_rank()
+
+ORT_DTYPE = TensorProto.FLOAT16
+NP_TYPE = np.float16 if ORT_DTYPE == TensorProto.FLOAT16 else np.float32
+THRESHOLD = 3e-2
+
+
+def create_moe_onnx_graph(
+    num_rows,
+    num_experts,
+    local_num_experts,
+    hidden_size,
+    inter_size,
+    fc1_experts_weights,
+    fc2_experts_weights,
+    fc1_experts_bias,
+    fc2_experts_bias,
+    local_experts_start_index=-1,
+):
+    # A non-negative start index means the graph should use the sharded (multi-GPU) op.
+    use_sharded_moe = local_experts_start_index >= 0
+    nodes = [
+        helper.make_node(
+            "MoE",
+            [
+                "input",
+                "router_probs",
+                "fc1_experts_weights",
+                "fc2_experts_weights",
+                "fc1_experts_bias",
+                "fc2_experts_bias",
+            ],
+            ["output"],
+            "MoE_0",
+            k=1,
+            activation_type="gelu",
+            domain="com.microsoft",
+        )
+        if not use_sharded_moe
+        else helper.make_node(
+            "ShardedMoE",
+            [
+                "input",
+                "router_probs",
+                "fc1_experts_weights",
+                "fc2_experts_weights",
+                "fc1_experts_bias",
+                "fc2_experts_bias",
+            ],
+            ["output"],
+            "MoE_0",
+            k=1,
+            activation_type="gelu",
+            local_experts_start_index=local_experts_start_index,
+            domain="com.microsoft",
+        ),
+    ]
+
+    fc1_shape = [local_num_experts, hidden_size, inter_size]
+    fc2_shape = [local_num_experts, inter_size, hidden_size]
+
+    initializers = [
+        helper.make_tensor(
+            "fc1_experts_weights",
+            ORT_DTYPE,
+            fc1_shape,
+            fc1_experts_weights.flatten(),
+            raw=False,
+        ),
+        helper.make_tensor(
+            "fc2_experts_weights",
+            ORT_DTYPE,
+            fc2_shape,
+            fc2_experts_weights.flatten(),
+            raw=False,
+        ),
+    ]
+
+    fc1_bias_shape = [local_num_experts, inter_size]
+    fc2_bias_shape = [local_num_experts, hidden_size]
+    initializers.extend(
+        [
+            helper.make_tensor(
+                "fc1_experts_bias",
+                ORT_DTYPE,
+                fc1_bias_shape,
+                fc1_experts_bias.flatten().tolist(),
+                raw=False,
+            ),
+            helper.make_tensor(
+                "fc2_experts_bias",
+                ORT_DTYPE,
+                fc2_bias_shape,
+                fc2_experts_bias.flatten().tolist(),
+                raw=False,
+            ),
+        ]
+    )
+
+    graph_inputs = [
+        helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]),
+    ]
+
+    graph_inputs.append(
+        helper.make_tensor_value_info(
+            "router_probs",
+            ORT_DTYPE,
+            [num_rows, num_experts],
+        )
+    )
+
+    graph_outputs = [
+        helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]),
+    ]
+
+    graph = helper.make_graph(
+        nodes,
+        "MoE_Graph",
+        graph_inputs,
+        graph_outputs,
+        initializers,
+    )
+
+    model = helper.make_model(graph)
+    return model.SerializeToString()
+
+
+# The test assumes two GPUs, one MPI rank per GPU (see run_script.sh).
+def main():
+    hidden_size = 32
+    inter_size = 64
+    num_experts = 16
+    num_rows = 32
+    local_num_experts = num_experts // get_size()
+    local_experts_start_index = local_rank * local_num_experts
+    local_experts_end_index = local_experts_start_index + local_num_experts
+
+    # create weights and bias
+    fc1_experts_weights_all = np.random.rand(num_experts, hidden_size, inter_size).astype(NP_TYPE)
+    fc2_experts_weights_all = np.random.rand(num_experts, inter_size, hidden_size).astype(NP_TYPE)
+    fc1_experts_bias_all = np.random.rand(num_experts, inter_size).astype(NP_TYPE)
+    fc2_experts_bias_all = np.random.rand(num_experts, hidden_size).astype(NP_TYPE)
+
+    # expert slicing by local rank
+    fc1_experts_weights = fc1_experts_weights_all[local_experts_start_index:local_experts_end_index, :, :]
+    fc2_experts_weights = fc2_experts_weights_all[local_experts_start_index:local_experts_end_index, :, :]
+    fc1_experts_bias = fc1_experts_bias_all[local_experts_start_index:local_experts_end_index, :]
+    fc2_experts_bias = fc2_experts_bias_all[local_experts_start_index:local_experts_end_index, :]
+
+    # create onnx graph
+    onnx_model = create_moe_onnx_graph(
+        num_rows,
+        num_experts,
+        local_num_experts,
+        hidden_size,
+        inter_size,
+        fc1_experts_weights,
+        fc2_experts_weights,
+        fc1_experts_bias,
+        fc2_experts_bias,
+        local_experts_start_index,
+    )
+
+    sess_options = onnxruntime.SessionOptions()
+    cuda_provider_options = {"enable_skip_layer_norm_strict_mode": False, "device_id": local_rank}
+    execution_providers = [("CUDAExecutionProvider", cuda_provider_options)]
+
+    ort_session = onnxruntime.InferenceSession(onnx_model, sess_options, providers=execution_providers)
+
+    input_name = ort_session.get_inputs()[0].name
+    router_probs_name = ort_session.get_inputs()[1].name
+    input_data = np.random.rand(num_rows, hidden_size).astype(NP_TYPE)
+    router_probs = np.random.rand(num_rows, num_experts).astype(NP_TYPE)
+    ort_inputs = {input_name: input_data, router_probs_name: router_probs}
+
+    output = ort_session.run(None, ort_inputs)
+    print_out("output shape:", output[0].shape)
+
+
+if __name__ == "__main__":
+    main()
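
For reference, here is a minimal standalone sketch (not part of the patch) of the per-rank expert-slicing arithmetic that test_sharded_moe.py relies on, assuming the default configuration in main() of 16 experts split across 2 MPI ranks; world_size and rank below are illustrative stand-ins for get_size() and get_rank():

# Hypothetical, illustrative sketch of the expert partitioning used by the test.
# Assumes num_experts is divisible by the world size (16 experts, 2 ranks by default).
num_experts = 16
world_size = 2  # stand-in for get_size()

for rank in range(world_size):  # stand-in for get_rank()
    local_num_experts = num_experts // world_size
    local_experts_start_index = rank * local_num_experts
    owned = range(local_experts_start_index, local_experts_start_index + local_num_experts)
    print(f"rank {rank} owns experts {owned[0]}..{owned[-1]}")  # rank 0: 0..7, rank 1: 8..15

Each rank builds its ShardedMoE graph with only its slice of the fc1/fc2 weights and biases, while router_probs still spans all num_experts experts.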