Skip to content

Commit

Permalink
Fix NCCL test hang (#9367)
Browse files Browse the repository at this point in the history
  • Loading branch information
rongou authored Jul 7, 2023
1 parent 41c6813 commit 15ca12a
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 8 deletions.
16 changes: 12 additions & 4 deletions src/collective/communicator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,18 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
old_device_ordinal = device_ordinal;
old_world_size = communicator_->GetWorldSize();
#ifdef XGBOOST_USE_NCCL
if (type_ != CommunicatorType::kFederated) {
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal));
} else {
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
switch (type_) {
case CommunicatorType::kRabit:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
break;
case CommunicatorType::kFederated:
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
break;
case CommunicatorType::kInMemory:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
break;
default:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
}
#else
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
Expand Down
10 changes: 8 additions & 2 deletions src/collective/nccl_device_communicator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
namespace xgboost {
namespace collective {

NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal)
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync)
: device_ordinal_{device_ordinal},
needs_sync_{needs_sync},
world_size_{GetWorldSize()},
rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
Expand Down Expand Up @@ -140,6 +143,9 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
// First gather data from all the workers.
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
nccl_comm_, cuda_stream_));
if (needs_sync_) {
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
}

// Then reduce locally.
auto *out_buffer = static_cast<char *>(send_receive_buffer);
Expand Down
16 changes: 15 additions & 1 deletion src/collective/nccl_device_communicator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,20 @@ namespace collective {

class NcclDeviceCommunicator : public DeviceCommunicator {
public:
explicit NcclDeviceCommunicator(int device_ordinal);
/**
* @brief Construct a new NCCL communicator.
* @param device_ordinal The GPU device id.
* @param needs_sync Whether extra CUDA stream synchronization is needed.
*
* In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes
* a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization
* makes sure that the NCCL kernels are caught up, thus avoiding the deadlock.
*
* The Rabit communicator runs with one process per GPU, so the additional synchronization is not
* needed. The in-memory communicator is used in tests with multiple threads, each thread
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
*/
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync);
~NcclDeviceCommunicator() override;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override;
Expand Down Expand Up @@ -60,6 +73,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
Operation op);

int const device_ordinal_;
bool const needs_sync_;
int const world_size_;
int const rank_;
ncclComm_t nccl_comm_{};
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/collective/test_nccl_device_communicator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace xgboost {
namespace collective {

TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
auto construct = []() { NcclDeviceCommunicator comm{-1}; };
auto construct = []() { NcclDeviceCommunicator comm{-1, false}; };
EXPECT_THROW(construct(), dmlc::Error);
}

Expand Down

0 comments on commit 15ca12a

Please sign in to comment.