Support column split in GPU evaluate splits (#9511)
rongou authored Aug 23, 2023
1 parent 8c10af4 commit 6103dca
Showing 11 changed files with 240 additions and 113 deletions.
14 changes: 14 additions & 0 deletions src/collective/communicator-inl.cuh
@@ -57,6 +57,20 @@ inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}

/**
* @brief Gather values from all processes.
*
* This assumes the data sent by each rank has the same size.
*
* @param device ID of the device.
* @param send_buffer Buffer storing the data to be sent.
* @param receive_buffer Buffer storing the gathered data.
* @param send_size Size of the sent data in bytes.
*/
inline void AllGather(int device, void const *send_buffer, void *receive_buffer,
std::size_t send_size) {
Communicator::GetDevice(device)->AllGather(send_buffer, receive_buffer, send_size);
}

/**
* @brief Gather variable-length values from all processes.
* @param device ID of the device.
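As a rough usage sketch of the new fixed-size AllGather wrapper (in the spirit of the federated test added at the end of this commit; the include paths, device ordinal, and helper function name are assumptions, not code from the commit):

#include <thrust/device_vector.h>
#include "collective/communicator-inl.cuh"  // assumed include path
#include "collective/communicator-inl.h"    // assumed, for GetRank/GetWorldSize

void ExampleAllGather() {
  int const device = 0;  // assumed device ordinal
  // Each rank contributes one double; the result is rank-ordered.
  thrust::device_vector<double> send(1, xgboost::collective::GetRank());
  thrust::device_vector<double> recv(xgboost::collective::GetWorldSize(), 0.0);
  xgboost::collective::AllGather(device, send.data().get(), recv.data().get(), sizeof(double));
  // recv[0] now holds rank 0's value, recv[1] rank 1's value, and so on.
}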
11 changes: 11 additions & 0 deletions src/collective/device_communicator.cuh
@@ -27,6 +27,17 @@ class DeviceCommunicator {
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) = 0;

/**
* @brief Gather values from all processes.
*
* This assumes the data sent by each rank has the same size.
*
* @param send_buffer Buffer storing the data to be sent.
* @param receive_buffer Buffer storing the gathered data.
* @param send_size Size of the sent data in bytes.
*/
virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0;

/**
* @brief Gather variable-length values from all processes.
* @param send_buffer Buffer storing the input data.
18 changes: 16 additions & 2 deletions src/collective/device_communicator_adapter.cuh
@@ -28,12 +28,26 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {

dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto size = count * GetTypeSize(data_type);
host_buffer_.reserve(size);
host_buffer_.resize(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
Allreduce(host_buffer_.data(), count, data_type, op);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
}

void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override {
if (world_size_ == 1) {
return;
}

dh::safe_cuda(cudaSetDevice(device_ordinal_));
host_buffer_.resize(send_size * world_size_);
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
cudaMemcpyDefault));
Allgather(host_buffer_.data(), host_buffer_.size());
dh::safe_cuda(
cudaMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), cudaMemcpyDefault));
}

void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
if (world_size_ == 1) {
@@ -49,7 +63,7 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);

host_buffer_.reserve(total_bytes);
host_buffer_.resize(total_bytes);
size_t offset = 0;
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
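For reference, the staging layout used by this CPU fallback can be sketched host-side as follows (a toy illustration under assumed names and sizes; the real adapter copies to and from device memory with cudaMemcpy before and after the CPU Allgather):

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch: each rank owns a send_size-byte slice of the staging buffer, ordered by rank.
std::vector<char> MakeStagingBuffer(std::size_t send_size, int world_size, int rank,
                                    char const* my_bytes) {
  std::vector<char> host_buffer(send_size * world_size);
  // Copy this rank's contribution into its slot; the CPU Allgather fills the other slots.
  std::copy(my_bytes, my_bytes + send_size, host_buffer.data() + rank * send_size);
  return host_buffer;
}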
11 changes: 11 additions & 0 deletions src/collective/nccl_device_communicator.cu
@@ -178,6 +178,17 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
allreduce_calls_ += 1;
}

void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer,
std::size_t send_size) {
if (world_size_ == 1) {
return;
}

dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllGather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
dh::DefaultStream()));
}

void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
1 change: 1 addition & 0 deletions src/collective/nccl_device_communicator.cuh
@@ -29,6 +29,7 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
~NcclDeviceCommunicator() override;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override;
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override;
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override;
void Synchronize() override;
19 changes: 18 additions & 1 deletion src/tree/gpu_hist/evaluate_splits.cu
@@ -5,8 +5,8 @@
#include <vector>
#include <limits>

#include "../../collective/communicator-inl.cuh"
#include "../../common/categorical.h"
#include "../../common/device_helpers.cuh"
#include "../../data/ellpack_page.cuh"
#include "evaluate_splits.cuh"
#include "expand_entry.cuh"
@@ -409,6 +409,23 @@ void GPUHistEvaluator::EvaluateSplits(
this->LaunchEvaluateSplits(max_active_features, d_inputs, shared_inputs,
evaluator, out_splits);

if (is_column_split_) {
// With column-wise data split, we gather the split candidates from all the workers and find the
// global best candidates.
auto const world_size = collective::GetWorldSize();
dh::TemporaryArray<DeviceSplitCandidate> all_candidate_storage(out_splits.size() * world_size);
auto all_candidates = dh::ToSpan(all_candidate_storage);
collective::AllGather(device_, out_splits.data(), all_candidates.data(),
out_splits.size() * sizeof(DeviceSplitCandidate));

// Reduce to get the best candidate from all workers.
dh::LaunchN(out_splits.size(), [world_size, all_candidates, out_splits] __device__(size_t i) {
for (auto rank = 0; rank < world_size; rank++) {
out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
}
});
}

auto d_sorted_idx = this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size());
auto d_entries = out_entries;
auto device_cats_accessor = this->DeviceCatStorage(nidx);
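To make the reduction step concrete, here is a minimal host-side sketch of picking the per-node best candidate from the rank-ordered gather results. The Candidate struct and Better helper are illustrative stand-ins, not the actual DeviceSplitCandidate; in the commit, the dh::LaunchN kernel above relies on DeviceSplitCandidate's operator+ to perform the equivalent selection on the device:

#include <cstddef>
#include <vector>

struct Candidate {       // illustrative stand-in for DeviceSplitCandidate
  float loss_chg{0.0f};
  int findex{-1};
};

// Assumed reduction: keep whichever candidate has the larger loss change.
Candidate Better(Candidate const& a, Candidate const& b) {
  return b.loss_chg > a.loss_chg ? b : a;
}

// out holds this worker's candidates (one per node); all holds world_size copies,
// laid out rank-contiguously as produced by the all-gather.
void ReduceBest(std::vector<Candidate> const& all, int world_size,
                std::vector<Candidate>* out) {
  std::size_t const n = out->size();
  for (std::size_t i = 0; i < n; ++i) {
    for (int rank = 0; rank < world_size; ++rank) {
      (*out)[i] = Better((*out)[i], all[rank * n + i]);
    }
  }
}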
6 changes: 5 additions & 1 deletion src/tree/gpu_hist/evaluate_splits.cuh
@@ -83,6 +83,9 @@ class GPUHistEvaluator {
// Number of elements of categorical storage type
// needed to hold categoricals for a single mode
std::size_t node_categorical_storage_size_ = 0;
// Is the data split column-wise?
bool is_column_split_ = false;
int32_t device_;

// Copy the categories from device to host asynchronously.
void CopyToHost( const std::vector<bst_node_t>& nidx);
@@ -136,7 +139,8 @@ class GPUHistEvaluator {
* \brief Reset the evaluator, should be called before any use.
*/
void Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft,
bst_feature_t n_features, TrainParam const &param, int32_t device);
bst_feature_t n_features, TrainParam const &param, bool is_column_split,
int32_t device);

/**
* \brief Get host category storage for nidx. Different from the internal version, this
9 changes: 5 additions & 4 deletions src/tree/gpu_hist/evaluator.cu
@@ -14,10 +14,9 @@

namespace xgboost {
namespace tree {
void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts,
common::Span<FeatureType const> ft,
bst_feature_t n_features, TrainParam const &param,
int32_t device) {
void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft,
bst_feature_t n_features, TrainParam const &param,
bool is_column_split, int32_t device) {
param_ = param;
tree_evaluator_ = TreeEvaluator{param, n_features, device};
has_categoricals_ = cuts.HasCategorical();
@@ -65,6 +64,8 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts,
return fidx;
});
}
is_column_split_ = is_column_split;
device_ = device;
}

common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(
3 changes: 2 additions & 1 deletion src/tree/updater_gpu_hist.cu
@@ -242,7 +242,8 @@ struct GPUHistMakerDevice {
page = sample.page;
gpair = sample.gpair;

this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, ctx_->gpu_id);
this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
dmat->Info().IsColumnSplit(), ctx_->gpu_id);

quantiser.reset(new GradientQuantiser(this->gpair));

24 changes: 23 additions & 1 deletion tests/cpp/plugin/test_federated_adapter.cu
@@ -11,8 +11,8 @@
#include "../../../plugin/federated/federated_communicator.h"
#include "../../../src/collective/communicator-inl.cuh"
#include "../../../src/collective/device_communicator_adapter.cuh"
#include "./helpers.h"
#include "../helpers.h"
#include "./helpers.h"

namespace xgboost::collective {

@@ -45,6 +45,28 @@ TEST_F(FederatedAdapterTest, MGPUAllReduceSum) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllReduceSum);
}

namespace {
void VerifyAllGather() {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
auto const device = GPUIDX;
common::SetDevice(device);
thrust::device_vector<double> send_buffer(1, rank);
thrust::device_vector<double> receive_buffer(world_size, 0);
collective::AllGather(device, send_buffer.data().get(), receive_buffer.data().get(),
sizeof(double));
thrust::host_vector<double> host_buffer = receive_buffer;
EXPECT_EQ(host_buffer.size(), world_size);
for (auto i = 0; i < world_size; i++) {
EXPECT_EQ(host_buffer[i], i);
}
}
} // anonymous namespace

TEST_F(FederatedAdapterTest, MGPUAllGather) {
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyAllGather);
}

namespace {
void VerifyAllGatherV() {
auto const world_size = collective::GetWorldSize();