fix build error
wangyems committed Nov 20, 2023
1 parent 7906e19 commit 31c5daa
Showing 5 changed files with 45 additions and 44 deletions.
25 changes: 0 additions & 25 deletions onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
@@ -24,31 +24,6 @@ namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 
-#define NCCL_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(NCCL_CALL(expr))
-
-static ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) {
-  if (type == DataTypeImpl::GetType<uint8_t>()) {
-    return ncclUint8;
-  } else if (type == DataTypeImpl::GetType<bool>()) {
-    // CUDA bool is 8-bit large.
-    return ncclUint8;
-  } else if (type == DataTypeImpl::GetType<int8_t>()) {
-    return ncclInt8;
-  } else if (type == DataTypeImpl::GetType<int32_t>()) {
-    return ncclInt32;
-  } else if (type == DataTypeImpl::GetType<int64_t>()) {
-    return ncclInt64;
-  } else if (type == DataTypeImpl::GetType<MLFloat16>()) {
-    return ncclFloat16;
-  } else if (type == DataTypeImpl::GetType<float>()) {
-    return ncclFloat32;
-  } else if (type == DataTypeImpl::GetType<double>()) {
-    return ncclFloat64;
-  } else {
-    ORT_THROW("Tensor type not supported in NCCL.");
-  }
-}
-
 namespace IPC {
 #define FLLOG LOGS_DEFAULT(VERBOSE)
 #define FLLOGERRNO LOGS_DEFAULT(WARNING) << "error:" << strerror(errno)
25 changes: 25 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
@@ -18,6 +18,31 @@ namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 
+#define NCCL_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(NCCL_CALL(expr))
+
+static ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) {
+  if (type == DataTypeImpl::GetType<uint8_t>()) {
+    return ncclUint8;
+  } else if (type == DataTypeImpl::GetType<bool>()) {
+    // CUDA bool is 8-bit large.
+    return ncclUint8;
+  } else if (type == DataTypeImpl::GetType<int8_t>()) {
+    return ncclInt8;
+  } else if (type == DataTypeImpl::GetType<int32_t>()) {
+    return ncclInt32;
+  } else if (type == DataTypeImpl::GetType<int64_t>()) {
+    return ncclInt64;
+  } else if (type == DataTypeImpl::GetType<MLFloat16>()) {
+    return ncclFloat16;
+  } else if (type == DataTypeImpl::GetType<float>()) {
+    return ncclFloat32;
+  } else if (type == DataTypeImpl::GetType<double>()) {
+    return ncclFloat64;
+  } else {
+    ORT_THROW("Tensor type not supported in NCCL.");
+  }
+}
+
 // -----------------------------------------------------------------------
 // Defines a new version of nccl classes
 // that independent with training::DistributedRunContext, only rely on MPI
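
Note: moving NCCL_RETURN_IF_ERROR and GetNcclDataType out of nccl_kernels.cc and into this header is what lets other translation units (sharded_moe.cc below) call them; the previous definitions were local to one .cc file and invisible elsewhere. A minimal caller sketch, assuming ONNX Runtime's Tensor API; BroadcastTensor is a hypothetical helper, not part of this diff:

// Hypothetical helper, for illustration only: broadcast a tensor in place,
// mapping its runtime element type to the matching NCCL enum through the
// GetNcclDataType definition this commit exposes in the header.
Status BroadcastTensor(Tensor& t, ncclComm_t comm, int root, cudaStream_t stream) {
  ncclDataType_t dtype = GetNcclDataType(t.DataType());  // ORT_THROWs on unsupported types
  NCCL_RETURN_IF_ERROR(ncclBroadcast(t.DataRaw(), t.MutableDataRaw(),
                                     t.Shape().Size(), dtype, root, comm, stream));
  return Status::OK();
}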
28 changes: 14 additions & 14 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -108,20 +108,20 @@ Status ShardedMoE<T>::ComputeInternal(OpKernelContext* context) const {
   int total_covered_rows = 0;
   moe_runner.get_local_rows_info(total_past_rows, total_covered_rows);
 
-  CheckIfMemoryOnCurrentGpuDevice(fc2_output.get());
-
-  NCCL_RETURN_IF_ERROR(ncclGroupStart());
-  size_t fc2_output_offset = total_past_rows * hidden_size_ * sizeof(CudaT);
-  char* fc2_output_ptr = reinterpret_cast<char*>(fc2_output.get()) + fc2_output_offset;
-  ncclDataType_t dtype = GetNcclDataType<T>();
-  NCCL_RETURN_IF_ERROR(ncclBroadcast(fc2_output_ptr,
-                                     fc2_output_ptr,
-                                     total_covered_rows * sizeof(CudaT),
-                                     dtype,
-                                     nccl_->Rank(),
-                                     comm,
-                                     Stream(context)));
-  NCCL_RETURN_IF_ERROR(ncclGroupEnd());
+  if (total_covered_rows > 0) {
+    NCCL_RETURN_IF_ERROR(ncclGroupStart());
+    size_t fc2_output_offset = total_past_rows * moe_params.hidden_size * sizeof(CudaT);
+    char* fc2_output_ptr = reinterpret_cast<char*>(fc2_output.get()) + fc2_output_offset;
+    ncclDataType_t dtype = GetNcclDataType(input->DataType());
+    NCCL_RETURN_IF_ERROR(ncclBroadcast(fc2_output_ptr,
+                                       fc2_output_ptr,
+                                       total_covered_rows * sizeof(CudaT),
+                                       dtype,
+                                       nccl_->Rank(),
+                                       comm,
+                                       Stream(context)));
+    NCCL_RETURN_IF_ERROR(ncclGroupEnd());
+  }
 
   ort_fastertransformer::finalize_moe_routing_kernelLauncher(
       reinterpret_cast<CudaT*>(fc2_output.get()), reinterpret_cast<CudaT*>(output->template MutableData<T>()),
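
The rewritten block fixes two compile errors: GetNcclDataType<T>() was a template call that no longer exists (the NCCL type now comes from the input tensor at runtime), and hidden_size_ was undeclared in this scope (now read from moe_params.hidden_size). It also skips the NCCL call entirely on ranks whose local experts cover no rows. A small, self-contained sketch of the byte-offset arithmetic, with made-up example values:

#include <cuda_fp16.h>

// Rows [total_past_rows, total_past_rows + total_covered_rows) of the
// row-major fc2 output buffer belong to this rank's local experts.
void OffsetSketch(void* fc2_output) {
  size_t total_past_rows = 128;  // rows handled by earlier ranks (example value)
  size_t hidden_size = 4096;     // columns per row (example value)
  size_t offset_bytes = total_past_rows * hidden_size * sizeof(half);
  char* local_rows = reinterpret_cast<char*>(fc2_output) + offset_bytes;
  (void)local_rows;  // this is where the in-place ncclBroadcast would start
}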
8 changes: 4 additions & 4 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -644,16 +644,16 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::run_moe_fc(
   }
 
   // expanded_active_expert_rows is not used
-  moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows * hidden_size,
+  moe_gemm_runner_.moe_gemm_bias_act(permuted_data_ + total_past_rows_ * hidden_size,
                                      fc1_expert_weights, fc1_scales, fc1_expert_biases,
-                                     fc1_result_ + total_past_rows * inter_size,
+                                     fc1_result_ + total_past_rows_ * inter_size,
                                      total_rows_before_expert_ + local_experts_start_index,
                                      expanded_active_expert_rows, inter_size, hidden_size,
                                      local_num_experts, fc1_activation_type, stream);
 
-  moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows * inter_size,
+  moe_gemm_runner_.moe_gemm(fc1_result_ + total_past_rows_ * inter_size,
                             fc2_expert_weights, fc2_scales,
-                            fc2_result + total_past_rows * hidden_size,
+                            fc2_result + total_past_rows_ * hidden_size,
                             total_rows_before_expert_ + local_experts_start_index,
                             expanded_active_expert_rows, hidden_size, inter_size, local_num_experts, stream);
 }
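
This is the heart of the build error: total_past_rows is not a local variable in run_moe_fc; the value lives in a data member, and ONNX Runtime's style marks members with a trailing underscore, hence total_past_rows_. A toy sketch of that convention (hypothetical class, not from the repo):

class RunnerSketch {
 public:
  // Writes the members that later calls read, mirroring dispatch_activations.
  void dispatch(int past_rows, int covered_rows) {
    total_past_rows_ = past_rows;
    total_covered_rows_ = covered_rows;
  }
  int total_past_rows() const { return total_past_rows_; }

 private:
  int total_past_rows_ = 0;  // trailing underscore: data member, not a local
  int total_covered_rows_ = 0;
};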
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -126,7 +126,8 @@ class CutlassMoeFCRunner {
                                        int64_t* total_rows_before_expert, cudaStream_t stream);
 
   void dispatch_activations(int64_t*& total_rows_before_expert, int num_experts, int local_num_experts,
-                            int local_experts_start_index, int& total_past_rows, cudaStream_t stream);
+                            int local_experts_start_index, int& total_past_rows, int& total_covered_rows,
+                            cudaStream_t stream);
 
   void get_local_rows_info(int& total_past_rows, int& total_covered_rows) {
     // cudaDeviceSynchronize();
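
With the extra int& total_covered_rows out-parameter, dispatch_activations now records both values that get_local_rows_info hands back, and the ShardedMoE hunk above consumes them. A caller sketch mirroring sharded_moe.cc:

int total_past_rows = 0;
int total_covered_rows = 0;
moe_runner.get_local_rows_info(total_past_rows, total_covered_rows);
if (total_covered_rows > 0) {
  // Only ranks whose local experts own rows join the broadcast.
}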
