Skip to content

Commit

Permalink
[ROCm] Remove MPI dependency and collectives to use NCCL (#19830)
Browse files Browse the repository at this point in the history
### Description
* Remove MPI dependency to use NCCL AllReduce, etc.
* Exclude unsupported collectives in hipify
  • Loading branch information
mindest authored Mar 20, 2024
1 parent 6fe0206 commit 3dfe4a5
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 20 deletions.
24 changes: 6 additions & 18 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -94,30 +94,18 @@ set(contrib_ops_excluded_files
"bert/group_query_attention.cc"
"bert/group_query_attention_impl.h"
"bert/group_query_attention_impl.cu"
"collective/distributed_*"
"collective/shard*"
)

if (NOT onnxruntime_ENABLE_ATEN)
list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc")
endif()
if (NOT onnxruntime_USE_NCCL)
# Those are string patterns to exclude. Do NOT use stars such as
# collective/*.cc or *.h.
list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc")
list(APPEND contrib_ops_excluded_files "collective/sharded_moe.h")
list(APPEND contrib_ops_excluded_files "collective/sharded_moe.cc")
list(APPEND contrib_ops_excluded_files "collective/sharding.cc")
list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_unsqueeze.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_squeeze.cc")
else()
# moe not supported for ROCm EP
list(APPEND contrib_ops_excluded_files "collective/sharded_moe.h")
list(APPEND contrib_ops_excluded_files "collective/sharded_moe.cc")
endif()

if (NOT onnxruntime_ENABLE_ATEN)
list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc")
endif()

set(provider_excluded_files
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather);
#endif

#if defined(USE_MPI) && defined(ORT_USE_NCCL)
#ifdef ORT_USE_NCCL
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll);
Expand Down Expand Up @@ -311,7 +311,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather)>,
#endif

#if defined(USE_MPI) && defined(ORT_USE_NCCL)
#ifdef ORT_USE_NCCL
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll)>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ void NcclService::Initialize() {
// CPUs
// Other devices

#ifdef USE_MPI
const int mpi_rank = onnxruntime::training::MPIContext::GetInstance().GetWorldRank();
const int mpi_local_rank = onnxruntime::training::MPIContext::GetInstance().GetLocalRank();
const int mpi_size = onnxruntime::training::MPIContext::GetInstance().GetWorldSize();
Expand All @@ -248,6 +249,7 @@ void NcclService::Initialize() {
if (mpi_rank == 0) NCCL_CALL_THROW(ncclGetUniqueId(&id));
MPI_CHECK(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
NCCL_CALL_THROW(ncclCommInitRank(&comm_, mpi_size, id, mpi_rank));
#endif // USE_MPI
}

void NcclService::Launch() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ jobs:
--rocm_version=$(RocmVersion) \
--rocm_home /opt/rocm \
--nccl_home /opt/rocm \
--enable_nccl \
--update \
--build_dir /build \
--build \
Expand Down

0 comments on commit 3dfe4a5

Please sign in to comment.