Skip to content

Commit

Permalink
mpi: in-place alltoall
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed May 13, 2024
1 parent 5fca161 commit d3913cc
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
36 changes: 35 additions & 1 deletion src/impl/KokkosComm_alltoall.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,38 @@ void alltoall(const ExecSpace &space, const SendView &sv,

Kokkos::Tools::popRegion();
}
} // namespace KokkosComm::Impl

// in-place alltoall
template <KokkosExecutionSpace ExecSpace, KokkosView RecvView>
void alltoall(const ExecSpace &space, const RecvView &rv,
const size_t recvCount, MPI_Comm comm) {
Kokkos::Tools::pushRegion("KokkosComm::Impl::alltoall");

using RT = KokkosComm::Traits<RecvView>;
using RecvScalar = typename RecvView::value_type;

static_assert(RT::rank() <= 1,
"alltoall for RecvView::rank > 1 not supported");

if (KokkosComm::PackTraits<RecvView>::needs_pack(rv)) {
throw std::runtime_error(
"alltoall for non-contiguous views not implemented");
} else {
int size;
MPI_Comm_size(comm, &size);

if (recvCount * size > RT::extent(rv, 0)) {
std::stringstream ss;
ss << "alltoall recvCount * communicator size (" << recvCount << " * "
<< size << ") is greater than recv view size";
throw std::runtime_error(ss.str());
}

MPI_Alltoall(MPI_IN_PLACE, 0 /*ignored*/, MPI_BYTE /*ignored*/,
RT::data_handle(rv), recvCount, mpi_type_v<RecvScalar>, comm);
}

Kokkos::Tools::popRegion();
}

} // namespace KokkosComm::Impl
31 changes: 31 additions & 0 deletions unit_tests/test_alltoall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,35 @@ TYPED_TEST(Alltoall, 1D_contig) {
EXPECT_EQ(errs, 0);
}

TYPED_TEST(Alltoall, 1D_inplace_contig) {
using TestScalar = typename TestFixture::Scalar;

int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);

const int nContrib = 10;

Kokkos::View<TestScalar *> rv("rv", size * nContrib);

// fill send buffer
Kokkos::parallel_for(
rv.extent(0), KOKKOS_LAMBDA(const int i) { rv(i) = rank + i; });

KokkosComm::Impl::alltoall(Kokkos::DefaultExecutionSpace(), rv, nContrib,
MPI_COMM_WORLD);

int errs;
Kokkos::parallel_reduce(
rv.extent(0),
KOKKOS_LAMBDA(const int &i, int &lsum) {
const int src = i / nContrib; // who sent this data
const int j =
rank * nContrib + (i % nContrib); // what index i was at the source
lsum += rv(i) != src + j;
},
errs);
EXPECT_EQ(errs, 0);
}

} // namespace

0 comments on commit d3913cc

Please sign in to comment.