Commit: add test files

wangyems committed Nov 20, 2023
1 parent 31c5daa commit 410c28d
Showing 6 changed files with 244 additions and 11 deletions.
9 changes: 9 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -50,6 +50,15 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights,
fc1_experts_bias_optional, fc2_experts_bias_optional));

// Debug: print MoE parameters for each rank (single statement so rank outputs interleave less).
std::cout << "nccl rank: " << nccl_->Rank()
          << " local_experts_start_index: " << local_experts_start_index_
          << " num_rows: " << moe_params.num_rows
          << " hidden_size: " << moe_params.hidden_size
          << " inter_size: " << moe_params.inter_size
          << " num_experts: " << moe_params.num_experts
          << " local_num_experts: " << moe_params.local_num_experts
          << " k: " << k_ << std::endl;

typedef typename ToCudaType<T>::MappedType CudaT;
auto stream = context->GetComputeStream();

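For orientation, these prints should show each rank owning a contiguous slice of the experts. A minimal sketch of the expected values, using the configuration from test_sharded_moe.py later in this commit (num_experts = 16, two ranks; the loop itself is an illustration, not part of the change):

# Expected per-rank values for the two-rank test added below (illustration only).
num_experts, world_size = 16, 2
for rank in range(world_size):
    local_num_experts = num_experts // world_size
    local_experts_start_index = rank * local_num_experts
    print(f"rank {rank}: local_experts_start_index={local_experts_start_index}, local_num_experts={local_num_experts}")
# rank 0: local_experts_start_index=0, local_num_experts=8
# rank 1: local_experts_start_index=8, local_num_experts=8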
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -365,6 +365,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul)>,

26 changes: 15 additions & 11 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cuda.h>
#include <cuda_fp16.h>
@@ -505,8 +507,6 @@ __global__ void compute_total_rows_before_expert_kernel(const int* sorted_expert
__global__ void dispatch_activations_kernel(int64_t*& total_rows_before_expert, int num_experts,
int local_num_experts, int local_experts_start_index,
int& total_past_rows, int& total_covered_rows) {
// permuted_experts_ : 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6,
// total_rows_before_expert_ : 1, 5, 12, 19, 25, 31, 32, 32,
const int expert = blockIdx.x * blockDim.x + threadIdx.x;
const int local_experts_end_index = local_experts_start_index + local_num_experts - 1;

@@ -522,6 +522,8 @@ __global__ void dispatch_activations_kernel(int64_t*& total_rows_before_expert,

template <typename T, typename WeightType, typename Enable>
CutlassMoeFCRunner<T, WeightType, Enable>::CutlassMoeFCRunner(int sm_version) {
total_past_rows_ = 0;
total_covered_rows_ = 0;
moe_gemm_runner_.initialize(sm_version);
}

@@ -615,32 +617,34 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
configure_ws_ptrs(workspace_ptr, num_rows, hidden_size, inter_size, num_experts, k);
topk_gating_softmax_kernelLauncher<T>(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row,
source_rows_, num_rows, num_experts, k, stream);
// source_rows_: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
// expert_for_source_row : 3, 4, 1, 1, 2, 1, 2, 4, 5, 4, 3, 3, 5, 2, 5, 4, 4, 3, 3, 3, 2, 5, 5, 0, 2, 3, 4, 6, 2, 1, 5, 2,
print_cuda_buffer("source_rows_", source_rows_, num_rows);
print_cuda_buffer("expert_for_source_row", expert_for_source_row, num_rows);

const int sorter_ws_size_bytes = static_cast<int>(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)));
sorter_.run((void*)fc1_result_, sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_,
permuted_rows_, k * num_rows, stream);
// permuted_experts_ : 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6,
// permuted_rows_ : 23, 2, 3, 5, 29, 4, 6, 13, 20, 24, 28, 31, 0, 10, 11, 17, 18, 19, 25, 1, 7, 9, 15, 16, 26, 8, 12, 14, 21, 22, 30, 27,
// source_rows_ : 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
print_cuda_buffer("permuted_experts_", permuted_experts_, num_rows);
print_cuda_buffer("permuted_rows_", permuted_rows_, num_rows);
print_cuda_buffer("source_rows_", source_rows_, num_rows);

initialize_moe_routing_kernelLauncher(input_activations, permuted_data_, permuted_rows_,
expanded_source_row_to_expanded_dest_row, num_rows, active_rows, hidden_size, k,
stream);
// expanded_source_row_to_expanded_dest_row : 12, 19, 1, 2, 5, 3, 6, 20, 25, 21, 13, 14, 26, 7, 27, 22, 23, 15, 16, 17, 8, 28, 29, 0, 9, 18, 24, 31, 10, 4, 30, 11,
// permuted_rows_ : 23, 2, 3, 5, 29, 4, 6, 13, 20, 24, 28, 31, 0, 10, 11, 17, 18, 19, 25, 1, 7, 9, 15, 16, 26, 8, 12, 14, 21, 22, 30, 27,
print_cuda_buffer("expanded_source_row_to_expanded_dest_row", expanded_source_row_to_expanded_dest_row, num_rows);
print_cuda_buffer("permuted_rows_", permuted_rows_, num_rows);

const int expanded_active_expert_rows = k * active_rows;
compute_total_rows_before_expert(permuted_experts_, expanded_active_expert_rows, num_experts,
total_rows_before_expert_, stream);
// total_rows_before_expert_ : 1, 5, 12, 19, 25, 31, 32, 32,

print_cuda_buffer("total_rows_before_expert_", total_rows_before_expert_, num_experts);
if (local_num_experts < num_experts) {
dispatch_activations(total_rows_before_expert_, num_experts, local_num_experts, local_experts_start_index,
total_past_rows_, total_covered_rows_, stream);
// TODO: use cuda event
cudaDeviceSynchronize();
std::cout << "total_past_rows_ = " << total_past_rows_ << ", total_covered_rows_ = " << total_covered_rows_
<< std::endl;
}

// expanded_active_expert_rows is not used
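The inline example values deleted above trace the routing bookkeeping that print_cuda_buffer now dumps at runtime. A host-side NumPy sketch of that bookkeeping (an illustration of the device logic, not the kernels themselves; the local expert range below is a hypothetical 4-rank shard):

import numpy as np

# expert_for_source_row from the deleted example comment above (32 rows, k = 1).
expert_for_source_row = np.array([3, 4, 1, 1, 2, 1, 2, 4, 5, 4, 3, 3, 5, 2, 5, 4,
                                  4, 3, 3, 3, 2, 5, 5, 0, 2, 3, 4, 6, 2, 1, 5, 2])
num_experts = 8

# sorter_.run: stable sort of rows by expert id.
permuted_rows = np.argsort(expert_for_source_row, kind="stable")
permuted_experts = expert_for_source_row[permuted_rows]

# compute_total_rows_before_expert: entry e counts rows assigned to experts 0..e.
total_rows_before_expert = np.searchsorted(permuted_experts, np.arange(num_experts), side="right")
print(total_rows_before_expert)  # [ 1  5 12 19 25 31 32 32]

# dispatch_activations: row range owned by this rank's local experts (hypothetical shard).
local_experts_start_index, local_num_experts = 2, 2
local_experts_end_index = local_experts_start_index + local_num_experts - 1
total_past_rows = int(total_rows_before_expert[local_experts_start_index - 1]) if local_experts_start_index > 0 else 0
total_covered_rows = int(total_rows_before_expert[local_experts_end_index]) - total_past_rows
print(total_past_rows, total_covered_rows)  # 5 14 -> rows 5..18 belong to experts 2 and 3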
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

10 changes: 10 additions & 0 deletions onnxruntime/test/python/transformers/sharded_moe/run_script.sh
@@ -0,0 +1,10 @@
#!/bin/bash

MPI="mpirun --allow-run-as-root
-mca btl_openib_warn_no_device_params_found 0 -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0
--tag-output --npernode 2 --bind-to numa
-x MIOPEN_FIND_MODE=1"

CMD="$MPI python test_sharded_moe.py"

set -x
$CMD
205 changes: 205 additions & 0 deletions onnxruntime/test/python/transformers/sharded_moe/test_sharded_moe.py
@@ -0,0 +1,205 @@
import numpy as np
import onnxruntime
from mpi4py import MPI
from onnx import TensorProto, helper

np.random.seed(42)

comm = MPI.COMM_WORLD


def get_rank():
    return comm.Get_rank()


def get_size():
    return comm.Get_size()


def barrier():
    comm.Barrier()


def print_out(*args):
    if get_rank() == 0:
        print(*args)


def broadcast(data):
    # mpi4py exposes bcast (lowercase, object-based); comm.broadcast does not exist.
    return comm.bcast(data, root=0)


local_rank = get_rank()

ORT_DTYPE = TensorProto.FLOAT16
NP_TYPE = np.float16 if ORT_DTYPE == TensorProto.FLOAT16 else np.float32
THRESHOLD = 3e-2


def create_moe_onnx_graph(
    num_rows,
    num_experts,
    local_num_experts,
    hidden_size,
    inter_size,
    fc1_experts_weights,
    fc2_experts_weights,
    fc1_experts_bias,
    fc2_experts_bias,
    local_experts_start_index=-1,
):
    # A non-negative start index selects the sharded op.
    use_sharded_moe = local_experts_start_index >= 0
    moe_attrs = {"k": 1, "activation_type": "gelu", "domain": "com.microsoft"}
    if use_sharded_moe:
        moe_attrs["local_experts_start_index"] = local_experts_start_index
    nodes = [
        helper.make_node(
            "ShardedMoE" if use_sharded_moe else "MoE",
            [
                "input",
                "router_probs",
                "fc1_experts_weights",
                "fc2_experts_weights",
                "fc1_experts_bias",
                "fc2_experts_bias",
            ],
            ["output"],
            "MoE_0",
            **moe_attrs,
        ),
    ]

    fc1_shape = [local_num_experts, hidden_size, inter_size]
    fc2_shape = [local_num_experts, inter_size, hidden_size]

    initializers = [
        helper.make_tensor(
            "fc1_experts_weights",
            ORT_DTYPE,
            fc1_shape,
            fc1_experts_weights.flatten(),
            raw=False,
        ),
        helper.make_tensor(
            "fc2_experts_weights",
            ORT_DTYPE,
            fc2_shape,
            fc2_experts_weights.flatten(),
            raw=False,
        ),
    ]

    fc1_bias_shape = [local_num_experts, inter_size]
    fc2_bias_shape = [local_num_experts, hidden_size]
    initializers.extend(
        [
            helper.make_tensor(
                "fc1_experts_bias",
                ORT_DTYPE,
                fc1_bias_shape,
                fc1_experts_bias.flatten().tolist(),
                raw=False,
            ),
            helper.make_tensor(
                "fc2_experts_bias",
                ORT_DTYPE,
                fc2_bias_shape,
                fc2_experts_bias.flatten().tolist(),
                raw=False,
            ),
        ]
    )

    graph_inputs = [
        helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]),
    ]

    graph_inputs.append(
        helper.make_tensor_value_info(
            "router_probs",
            ORT_DTYPE,
            [num_rows, num_experts],
        )
    )

    graph_outputs = [
        helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]),
    ]

    graph = helper.make_graph(
        nodes,
        "MoE_Graph",
        graph_inputs,
        graph_outputs,
        initializers,
    )

    model = helper.make_model(graph)
    return model.SerializeToString()


# Runs under MPI with two ranks, one GPU each (see run_script.sh).
def main():
    hidden_size = 32
    inter_size = 64
    num_experts = 16
    num_rows = 32
    local_num_experts = num_experts // get_size()
    local_experts_start_index = local_rank * num_experts // get_size()

    # create weights and bias
    fc1_experts_weights_all = np.random.rand(num_experts, hidden_size, inter_size).astype(NP_TYPE)
    fc2_experts_weights_all = np.random.rand(num_experts, inter_size, hidden_size).astype(NP_TYPE)
    fc1_experts_bias_all = np.random.rand(num_experts, inter_size).astype(NP_TYPE)
    fc2_experts_bias_all = np.random.rand(num_experts, hidden_size).astype(NP_TYPE)

    # expert slicing by local rank
    local_experts_end_index = local_experts_start_index + local_num_experts
    fc1_experts_weights = fc1_experts_weights_all[local_experts_start_index:local_experts_end_index, :, :]
    fc2_experts_weights = fc2_experts_weights_all[local_experts_start_index:local_experts_end_index, :, :]
    fc1_experts_bias = fc1_experts_bias_all[local_experts_start_index:local_experts_end_index, :]
    fc2_experts_bias = fc2_experts_bias_all[local_experts_start_index:local_experts_end_index, :]

    # create onnx graph
    onnx_model = create_moe_onnx_graph(
        num_rows,
        num_experts,
        local_num_experts,
        hidden_size,
        inter_size,
        fc1_experts_weights,
        fc2_experts_weights,
        fc1_experts_bias,
        fc2_experts_bias,
        local_experts_start_index,
    )

    sess_options = onnxruntime.SessionOptions()
    cuda_provider_options = {"enable_skip_layer_norm_strict_mode": False, "device_id": local_rank}
    execution_providers = [("CUDAExecutionProvider", cuda_provider_options)]

    ort_session = onnxruntime.InferenceSession(onnx_model, sess_options, providers=execution_providers)

    input_name = ort_session.get_inputs()[0].name
    router_probs_name = ort_session.get_inputs()[1].name
    input_data = np.random.rand(num_rows, hidden_size).astype(NP_TYPE)
    router_probs = np.random.rand(num_rows, num_experts).astype(NP_TYPE)
    ort_inputs = {input_name: input_data, router_probs_name: router_probs}

    output = ort_session.run(None, ort_inputs)


if __name__ == "__main__":
    main()
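One gap worth noting: the script defines THRESHOLD but never asserts anything, and output is computed and discarded. A hedged sketch of how the check could be closed, assuming the sharded result is meant to match a single-device MoE built from the same full weight tensors (the helper name and the comparison itself are assumptions, not part of this commit):

# Hypothetical follow-up (not in this commit): compare against an unsharded reference.
def verify_against_unsharded(output, ort_inputs, sess_options, execution_providers,
                             num_rows, num_experts, hidden_size, inter_size,
                             fc1_w_all, fc2_w_all, fc1_b_all, fc2_b_all):
    # Build the plain MoE graph with all experts (local_experts_start_index stays -1).
    ref_model = create_moe_onnx_graph(
        num_rows, num_experts, num_experts, hidden_size, inter_size,
        fc1_w_all, fc2_w_all, fc1_b_all, fc2_b_all,
    )
    ref_session = onnxruntime.InferenceSession(ref_model, sess_options, providers=execution_providers)
    ref_output = ref_session.run(None, ort_inputs)
    np.testing.assert_allclose(output[0], ref_output[0], atol=THRESHOLD, rtol=THRESHOLD)
    print_out("sharded MoE matches unsharded reference within THRESHOLD")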
