diff --git a/README.md b/README.md index 9483fb9..0506e4e 100644 --- a/README.md +++ b/README.md @@ -45,9 +45,10 @@ $ cmake -DCUDECOMP_BUILD_EXTRAS=1 -DCUDECOMP_ENABLE_NVSHMEM=1 .. ### Dependencies We strongly recommend building this library using NVHPC SDK compilers and libraries, as the SDK contains all required dependencies for this library and is the focus of our testing. Fortran features are only supported using NVHPC SDK compilers. -One exception is NVSHMEM, which uses a bootstrapping layer that depends on your MPI installation. The NVSHMEM library packaged within NVHPC +One exception is cuDecomp builds using NVSHMEM versions older than v3.0, which require the use of a bootstrapping layer that depends on your MPI distribution. The NVSHMEM library packaged within NVHPC SDK supports OpenMPI only. If you require usage of a different MPI implementation (e.g. Spectrum MPI or Cray MPICH), you need to either build -NVSHMEM against your desired MPI implementation, or build a custom MPI bootstrap layer. Please refer to this [NVSHMEM documentation section](https://docs.nvidia.com/hpc-sdk/nvshmem/install-guide/index.html#use-nvshmem-mpi) for more details. +NVSHMEM against your desired MPI implementation, or build a custom MPI bootstrap layer separately. Please refer to this [NVSHMEM documentation section](https://docs.nvidia.com/hpc-sdk/nvshmem/install-guide/index.html#use-nvshmem-mpi) for more details. +For cuDecomp builds using NVSHMEM v3.0+, this additional MPI bootstrapping layer is no longer required. Additionally, this library utilizes CUDA-aware MPI and is only compatible with MPI libraries with these features enabled. 
diff --git a/src/cudecomp.cc b/src/cudecomp.cc index ddccb12..89b452a 100644 --- a/src/cudecomp.cc +++ b/src/cudecomp.cc @@ -71,6 +71,26 @@ static ncclComm_t ncclCommFromMPIComm(MPI_Comm mpi_comm) { return nccl_comm; } +static void initNvshmemFromMPIComm(MPI_Comm mpi_comm) { + int rank, nranks; + CHECK_MPI(MPI_Comm_rank(mpi_comm, &rank)); + CHECK_MPI(MPI_Comm_size(mpi_comm, &nranks)); + + nvshmemx_init_attr_t attr; +#if NVSHMEM_VENDOR_MAJOR_VERSION >= 3 + nvshmemx_uniqueid_t id; + if (rank == 0) nvshmemx_get_uniqueid(&id); + CHECK_MPI(MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, mpi_comm)); + CHECK_MPI(MPI_Barrier(mpi_comm)); + nvshmemx_set_attr_uniqueid_args(rank, nranks, &id, &attr); + nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr); +#else + attr.mpi_comm = &mpi_comm; + nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr); +#endif + +} + static void checkTransposeCommBackend(cudecompTransposeCommBackend_t comm_backend) { switch (comm_backend) { case CUDECOMP_TRANSPOSE_COMM_NCCL: @@ -346,9 +366,7 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes #ifdef ENABLE_NVSHMEM if (!handle->nvshmem_initialized) { inspectNvshmemEnvVars(handle); - nvshmemx_init_attr_t attr; - attr.mpi_comm = &handle->mpi_comm; - nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr); + initNvshmemFromMPIComm(handle->mpi_comm); handle->nvshmem_initialized = true; handle->nvshmem_allocation_size = 0; }