diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 6f54943f09afe..cadb06bb38707 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -94,30 +94,18 @@ set(contrib_ops_excluded_files "bert/group_query_attention.cc" "bert/group_query_attention_impl.h" "bert/group_query_attention_impl.cu" + "collective/distributed_*" + "collective/shard*" ) -if (NOT onnxruntime_ENABLE_ATEN) - list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc") -endif() if (NOT onnxruntime_USE_NCCL) # Those are string patterns to exclude. Do NOT use stars such as # collective/*.cc or *.h. list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc") - list(APPEND contrib_ops_excluded_files "collective/sharded_moe.h") - list(APPEND contrib_ops_excluded_files "collective/sharded_moe.cc") - list(APPEND contrib_ops_excluded_files "collective/sharding.cc") - list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_unsqueeze.cc") - list(APPEND contrib_ops_excluded_files "collective/distributed_squeeze.cc") -else() - # moe not supported for ROCm EP - list(APPEND contrib_ops_excluded_files "collective/sharded_moe.h") - list(APPEND contrib_ops_excluded_files "collective/sharded_moe.cc") +endif() + +if (NOT onnxruntime_ENABLE_ATEN) + list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc") endif() set(provider_excluded_files diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 382a3951f3a83..e19a976f3141c 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -151,7 +151,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, ShrunkenGather); #endif -#if defined(USE_MPI) && defined(ORT_USE_NCCL) +#ifdef ORT_USE_NCCL class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllReduce); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllGather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, AllToAll); @@ -311,7 +311,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif -#if defined(USE_MPI) && defined(ORT_USE_NCCL) +#ifdef ORT_USE_NCCL BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/communication/nccl_service.cc b/orttraining/orttraining/training_ops/cuda/communication/nccl_service.cc index f604e4c4aaf3e..c642a87e22de6 100644 --- a/orttraining/orttraining/training_ops/cuda/communication/nccl_service.cc +++ b/orttraining/orttraining/training_ops/cuda/communication/nccl_service.cc @@ -233,6 +233,7 @@ void NcclService::Initialize() { // CPUs // Other devices +#ifdef USE_MPI const int mpi_rank = onnxruntime::training::MPIContext::GetInstance().GetWorldRank(); const int mpi_local_rank = onnxruntime::training::MPIContext::GetInstance().GetLocalRank(); const int mpi_size = onnxruntime::training::MPIContext::GetInstance().GetWorldSize(); @@ -248,6 +249,7 @@ void NcclService::Initialize() { if (mpi_rank == 0) NCCL_CALL_THROW(ncclGetUniqueId(&id)); MPI_CHECK(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); NCCL_CALL_THROW(ncclCommInitRank(&comm_, mpi_size, id, mpi_rank)); +#endif // USE_MPI } void NcclService::Launch() { diff --git a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml index 9cf7a3fb42397..8b58d958ba899 100644 --- a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml @@ -109,6 +109,7 @@ jobs: --rocm_version=$(RocmVersion) \ --rocm_home /opt/rocm \ --nccl_home /opt/rocm \ + --enable_nccl \ --update \ --build_dir /build \ --build \