diff --git a/src/collective/aggregator.cuh b/src/collective/aggregator.cuh
new file mode 100644
index 000000000000..a87a968ab5c3
--- /dev/null
+++ b/src/collective/aggregator.cuh
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2023 by XGBoost contributors
+ *
+ * Higher level functions built on top of the Communicator API, taking care of behavioral
+ * differences between row-split vs column-split distributed training, and horizontal vs
+ * vertical federated learning.
+ */
+#pragma once
+#include <xgboost/data.h>
+
+#include <limits>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "communicator-inl.cuh"
+
+namespace xgboost {
+namespace collective {
+
+/**
+ * @brief Find the global sum of the given values across all workers.
+ *
+ * This only applies when the data is split row-wise (horizontally). When data is split
+ * column-wise (vertically), the original values are returned.
+ *
+ * @tparam T The type of the values.
+ * @param info MetaInfo about the DMatrix.
+ * @param device The device id.
+ * @param values Pointer to the inputs to sum.
+ * @param size Number of values to sum.
+ */
+template <typename T>
+void GlobalSum(MetaInfo const& info, int device, T* values, size_t size) {
+  if (info.IsRowSplit()) {
+    collective::AllReduce<collective::Operation::kSum>(device, values, size);
+  }
+}
+}  // namespace collective
+}  // namespace xgboost
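A note on the helper above: `GlobalSum` is what lets the rest of this patch stay split-mode agnostic. Below is a minimal sketch (not part of the patch) of the intended call pattern; `Example`, `device`, and `partial` are hypothetical names, while `MetaInfo` and `DataSplitMode` come from `xgboost/data.h`.

    // Sketch: how a caller drives GlobalSum under the two split modes.
    #include <xgboost/data.h>
    #include "communicator-inl.cuh"

    void Example(int device, double* partial, std::size_t size) {
      xgboost::MetaInfo info;

      // Row split (horizontal): each worker holds a shard of the rows, so local
      // sums are partial and must be all-reduced across workers.
      info.data_split_mode = xgboost::DataSplitMode::kRow;
      xgboost::collective::GlobalSum(info, device, partial, size);  // allreduce

      // Column split (vertical): every worker already sees all rows, so the
      // local sums are globally correct and GlobalSum intentionally does nothing.
      info.data_split_mode = xgboost::DataSplitMode::kCol;
      xgboost::collective::GlobalSum(info, device, partial, size);  // no-op
    }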
diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu
index ecfc6c3ce468..b9a4424a5a68 100644
--- a/src/tree/gpu_hist/evaluate_splits.cu
+++ b/src/tree/gpu_hist/evaluate_splits.cu
@@ -418,7 +418,8 @@ void GPUHistEvaluator::EvaluateSplits(
   // Reduce to get the best candidate from all workers.
   dh::LaunchN(out_splits.size(), [world_size, all_candidates, out_splits] __device__(size_t i) {
-    for (auto rank = 0; rank < world_size; rank++) {
+    out_splits[i] = all_candidates[i];
+    for (auto rank = 1; rank < world_size; rank++) {
       out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
     }
   });
 
diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
index 489c8d6f7809..22eb7ab8185c 100644
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -8,6 +8,7 @@
 #include <cstdint>  // uint32_t
 #include <limits>
 
+#include "../../collective/aggregator.h"
 #include "../../common/deterministic.cuh"
 #include "../../common/device_helpers.cuh"
 #include "../../data/ellpack_page.cuh"
@@ -52,7 +53,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
  *
  * to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
  */
-GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair) {
+GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair, MetaInfo const& info) {
   using GradientSumT = GradientPairPrecise;
   using T = typename GradientSumT::ValueT;
   dh::XGBCachingDeviceAllocator<char> alloc;
@@ -64,11 +65,11 @@ GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair) {
   // Treat pair as array of 4 primitive types to allreduce
   using ReduceT = typename decltype(p.first)::ValueT;
   static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements.");
-  collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<ReduceT*>(&p), 4);
+  collective::GlobalSum(info, reinterpret_cast<ReduceT*>(&p), 4);
   GradientPair positive_sum{p.first}, negative_sum{p.second};
 
   std::size_t total_rows = gpair.size();
-  collective::Allreduce<collective::Operation::kSum>(&total_rows, 1);
+  collective::GlobalSum(info, &total_rows, 1);
 
   auto histogram_rounding =
       GradientSumT{common::CreateRoundingFactor<T>(
diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh
index eb9008d48376..c693e2e62e52 100644
--- a/src/tree/gpu_hist/histogram.cuh
+++ b/src/tree/gpu_hist/histogram.cuh
@@ -39,7 +39,7 @@ private:
   GradientPairPrecise to_floating_point_;
 
 public:
-  explicit GradientQuantiser(common::Span<GradientPair const> gpair);
+  GradientQuantiser(common::Span<GradientPair const> gpair, MetaInfo const& info);
   XGBOOST_DEVICE GradientPairInt64 ToFixedPoint(GradientPair const& gpair) const {
     auto adjusted = GradientPairInt64(gpair.GetGrad() * to_fixed_point_.GetGrad(),
                                       gpair.GetHess() * to_fixed_point_.GetHess());
diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh
index 215a0e49bde9..64ca540f667d 100644
--- a/src/tree/gpu_hist/row_partitioner.cuh
+++ b/src/tree/gpu_hist/row_partitioner.cuh
@@ -129,7 +129,7 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
     int batch_idx;
     std::size_t item_idx;
     AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx);
-    auto op_res = op(ridx[item_idx], batch_info_itr[batch_idx].data);
+    auto op_res = op(ridx[item_idx], batch_idx, batch_info_itr[batch_idx].data);
     return IndexFlagTuple{static_cast<bst_uint>(item_idx), op_res, batch_idx, op_res};
   });
   size_t temp_bytes = 0;
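The `SortPositionBatch` change above widens the position-update functor from `op(ridx, data)` to `op(ridx, batch_idx, data)`: the extra argument identifies which split candidate in the batch is being evaluated, which the column-split path added later in this patch needs in order to address per-candidate bit vectors. A sketch of the new shape, reusing `decision_bits` and `num_candidates` from the `UpdatePositionColumnSplit` hunk below:

    // Old functor shape: the op could not tell which candidate in the batch
    // it was evaluating.
    //   auto op = [=] __device__(bst_uint ridx, NodeSplitData const& data) { ... };
    //
    // New shape: split_index enables flat indexing into storage shared by all
    // candidates in the batch, one slot per (row, candidate) pair.
    auto op = [=] __device__(bst_uint ridx, int split_index, NodeSplitData const& data) {
      return decision_bits.Check(ridx * num_candidates + split_index);
    };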
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 0e42f1562f48..57eec0db8766 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -12,7 +12,8 @@
 #include <utility>  // for move
 #include <vector>
 
-#include "../collective/communicator-inl.cuh"
+#include "../collective/aggregator.h"
+#include "../collective/aggregator.cuh"
 #include "../common/bitfield.h"
 #include "../common/categorical.h"
 #include "../common/cuda_context.cuh"  // CUDAContext
@@ -161,6 +162,7 @@ struct GPUHistMakerDevice {
   GPUHistEvaluator evaluator_;
   Context const* ctx_;
   std::shared_ptr<common::ColumnSampler> column_sampler_;
+  MetaInfo const& info_;
 
  public:
   EllpackPageImpl const* page{nullptr};
@@ -193,13 +195,14 @@ struct GPUHistMakerDevice {
   GPUHistMakerDevice(Context const* ctx, bool is_external_memory,
                      common::Span<FeatureType const> _feature_types, bst_row_t _n_rows,
                      TrainParam _param, std::shared_ptr<common::ColumnSampler> column_sampler,
-                     uint32_t n_features, BatchParam batch_param)
+                     uint32_t n_features, BatchParam batch_param, MetaInfo const& info)
       : evaluator_{_param, n_features, ctx->gpu_id},
         ctx_(ctx),
         feature_types{_feature_types},
         param(std::move(_param)),
         column_sampler_(std::move(column_sampler)),
-        interaction_constraints(param, n_features) {
+        interaction_constraints(param, n_features),
+        info_{info} {
     sampler = std::make_unique<GradientBasedSampler>(ctx, _n_rows, batch_param, param.subsample,
                                                      param.sampling_method, is_external_memory);
     if (!param.monotone_constraints.empty()) {
@@ -245,7 +248,7 @@ struct GPUHistMakerDevice {
     this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param,
                            dmat->Info().IsColumnSplit(), ctx_->gpu_id);
 
-    quantiser = std::make_unique<GradientQuantiser>(this->gpair);
+    quantiser = std::make_unique<GradientQuantiser>(this->gpair, dmat->Info());
 
     row_partitioner.reset();  // Release the device memory first before reallocating
     row_partitioner = std::make_unique<RowPartitioner>(ctx_->gpu_id, sample.sample_rows);
@@ -369,6 +372,66 @@ struct GPUHistMakerDevice {
     common::KCatBitField node_cats;
   };
 
+  void UpdatePositionColumnSplit(EllpackDeviceAccessor d_matrix,
+                                 std::vector<NodeSplitData> const& split_data,
+                                 std::vector<bst_node_t> const& nidx,
+                                 std::vector<bst_node_t> const& left_nidx,
+                                 std::vector<bst_node_t> const& right_nidx) {
+    auto const num_candidates = split_data.size();
+
+    using BitVector = LBitField64;
+    using BitType = BitVector::value_type;
+    auto const size = BitVector::ComputeStorageSize(d_matrix.n_rows * num_candidates);
+    dh::TemporaryArray<BitType> decision_storage(size, 0);
+    dh::TemporaryArray<BitType> missing_storage(size, 0);
+    BitVector decision_bits{dh::ToSpan(decision_storage)};
+    BitVector missing_bits{dh::ToSpan(missing_storage)};
+
+    dh::TemporaryArray<NodeSplitData> split_data_storage(num_candidates);
+    dh::safe_cuda(cudaMemcpyAsync(split_data_storage.data().get(), split_data.data(),
+                                  num_candidates * sizeof(NodeSplitData), cudaMemcpyDefault));
+    auto d_split_data = dh::ToSpan(split_data_storage);
+
+    dh::LaunchN(d_matrix.n_rows, [=] __device__(std::size_t ridx) mutable {
+      for (auto i = 0; i < num_candidates; i++) {
+        auto const& data = d_split_data[i];
+        auto const cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex());
+        if (isnan(cut_value)) {
+          missing_bits.Set(ridx * num_candidates + i);
+        } else {
+          bool go_left;
+          if (data.split_type == FeatureType::kCategorical) {
+            go_left = common::Decision(data.node_cats.Bits(), cut_value);
+          } else {
+            go_left = cut_value <= data.split_node.SplitCond();
+          }
+          if (go_left) {
+            decision_bits.Set(ridx * num_candidates + i);
+          }
+        }
+      }
+    });
+
+    collective::AllReduce<collective::Operation::kBitwiseOR>(
+        ctx_->gpu_id, decision_storage.data().get(), decision_storage.size());
+    collective::AllReduce<collective::Operation::kBitwiseOR>(
+        ctx_->gpu_id, missing_storage.data().get(), missing_storage.size());
+    collective::Synchronize(ctx_->gpu_id);
+
+    row_partitioner->UpdatePositionBatch(
+        nidx, left_nidx, right_nidx, split_data,
+        [=] __device__(bst_uint ridx, int split_index, NodeSplitData const& data) {
+          auto const index = ridx * num_candidates + split_index;
+          bool go_left;
+          if (missing_bits.Check(index)) {
+            go_left = data.split_node.DefaultLeft();
+          } else {
+            go_left = decision_bits.Check(index);
+          }
+          return go_left;
+        });
+  }
+
   void UpdatePosition(std::vector<GPUExpandEntry> const& candidates, RegTree* p_tree) {
     if (candidates.empty()) {
       return;
@@ -392,9 +455,15 @@ struct GPUHistMakerDevice {
     }
 
     auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
+
+    if (info_.IsColumnSplit()) {
+      UpdatePositionColumnSplit(d_matrix, split_data, nidx, left_nidx, right_nidx);
+      return;
+    }
+
     row_partitioner->UpdatePositionBatch(
         nidx, left_nidx, right_nidx, split_data,
-        [=] __device__(bst_uint ridx, const NodeSplitData& data) {
+        [=] __device__(bst_uint ridx, int split_index, const NodeSplitData& data) {
           // given a row index, returns the node id it belongs to
           float cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex());
           // Missing value
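How `UpdatePositionColumnSplit` works: under a vertical split only the worker that owns a split's feature can evaluate it, so each worker records its verdicts in two row-major bit vectors (one decision bit and one missing bit per row/candidate pair), leaving bits for features it does not own at zero. A bitwise-OR allreduce then merges the vectors, after which every worker partitions rows identically. The flat layout, illustrated for a hypothetical 3 rows and 2 candidates:

    // Bit (ridx * num_candidates + i) belongs to row ridx and candidate i.
    // With n_rows = 3, num_candidates = 2:
    //
    //   bit index:    0      1      2      3      4      5
    //   (row, cand):  (0,0)  (0,1)  (1,0)  (1,1)  (2,0)  (2,1)
    //
    // Only the feature owner sets a bit; every other worker leaves 0, so the
    // bitwise-OR allreduce reconstructs the full decision/missing tables on
    // all workers.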
@@ -544,9 +613,8 @@ struct GPUHistMakerDevice {
     monitor.Start("AllReduce");
     auto d_node_hist = hist.GetNodeHistogram(nidx).data();
     using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
-    collective::AllReduce<collective::Operation::kSum>(
-        ctx_->gpu_id, reinterpret_cast<ReduceT*>(d_node_hist),
-        page->Cuts().TotalBins() * 2 * num_histograms);
+    collective::GlobalSum(info_, ctx_->gpu_id, reinterpret_cast<ReduceT*>(d_node_hist),
+                          page->Cuts().TotalBins() * 2 * num_histograms);
 
     monitor.Stop("AllReduce");
   }
@@ -663,8 +731,7 @@ struct GPUHistMakerDevice {
     dh::Reduce(ctx_->CUDACtx()->CTP(), gpair_it, gpair_it + gpair.size(), GradientPairInt64{},
                thrust::plus<GradientPairInt64>{});
     using ReduceT = typename decltype(root_sum_quantised)::ValueT;
-    collective::Allreduce<collective::Operation::kSum>(
-        reinterpret_cast<ReduceT*>(&root_sum_quantised), 2);
+    collective::GlobalSum(info_, reinterpret_cast<ReduceT*>(&root_sum_quantised), 2);
 
     hist.AllocateHistograms({kRootNIdx});
     this->BuildHist(kRootNIdx);
@@ -801,7 +868,7 @@ class GPUHistMaker : public TreeUpdater {
     info_->feature_types.SetDevice(ctx_->gpu_id);
     maker = std::make_unique<GPUHistMakerDevice>(
         ctx_, !dmat->SingleColBlock(), info_->feature_types.ConstDeviceSpan(), info_->num_row_,
-        *param, column_sampler_, info_->num_col_, batch_param);
+        *param, column_sampler_, info_->num_col_, batch_param, dmat->Info());
 
     p_last_fmat_ = dmat;
     initialised_ = true;
@@ -915,7 +982,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
     auto batch = BatchParam{param->max_bin, hess, !task_->const_hess};
     maker_ = std::make_unique<GPUHistMakerDevice>(
         ctx_, !p_fmat->SingleColBlock(), info.feature_types.ConstDeviceSpan(), info.num_row_,
-        *param, column_sampler_, info.num_col_, batch);
+        *param, column_sampler_, info.num_col_, batch, p_fmat->Info());
 
     std::size_t t_idx{0};
     for (xgboost::RegTree* tree : trees) {
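Note the two aggregator headers pulled in at the top of this file: the hunks above use both a host-side and a device-side `GlobalSum`. The signatures below are an assumption inferred from the call sites here (the host overload lives in `aggregator.h`, which this diff does not show):

    // Host buffers (e.g. the quantised root sum) -- aggregator.h:
    //   template <typename T>
    //   void GlobalSum(MetaInfo const& info, T* values, size_t size);
    //
    // Device buffers (e.g. the node histograms) -- aggregator.cuh, shown at
    // the top of this patch, additionally takes the CUDA device ordinal:
    //   template <typename T>
    //   void GlobalSum(MetaInfo const& info, int device, T* values, size_t size);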
diff --git a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu
index f74b7d3caef6..f4ed34bf0c0e 100644
--- a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu
+++ b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu
@@ -24,7 +24,7 @@ auto ZeroParam() {
 inline GradientQuantiser DummyRoundingFactor() {
   thrust::device_vector<GradientPair> gpair(1);
   gpair[0] = {1000.f, 1000.f};  // Tests should not exceed sum of 1000
-  return GradientQuantiser(dh::ToSpan(gpair));
+  return {dh::ToSpan(gpair), MetaInfo()};
 }
 
 thrust::device_vector<GradientPairInt64> ConvertToInteger(std::vector<GradientPairPrecise> x) {
diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu
index 024a1e8d3a92..2eacd48e566f 100644
--- a/tests/cpp/tree/gpu_hist/test_histogram.cu
+++ b/tests/cpp/tree/gpu_hist/test_histogram.cu
@@ -39,7 +39,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
     FeatureGroups feature_groups(page->Cuts(), page->is_dense, shm_size,
                                  sizeof(GradientPairInt64));
 
-    auto quantiser = GradientQuantiser(gpair.DeviceSpan());
+    auto quantiser = GradientQuantiser(gpair.DeviceSpan(), MetaInfo());
     BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
                            feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), ridx,
                            d_histogram, quantiser);
@@ -53,7 +53,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size) {
       dh::device_vector<GradientPairInt64> new_histogram(num_bins);
       auto d_new_histogram = dh::ToSpan(new_histogram);
 
-      auto quantiser = GradientQuantiser(gpair.DeviceSpan());
+      auto quantiser = GradientQuantiser(gpair.DeviceSpan(), MetaInfo());
       BuildGradientHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(0),
                              feature_groups.DeviceAccessor(0), gpair.DeviceSpan(), ridx,
                              d_new_histogram, quantiser);
@@ -131,7 +131,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   dh::device_vector<GradientPairInt64> cat_hist(num_categories);
   auto gpair = GenerateRandomGradients(kRows, 0, 2);
   gpair.SetDevice(0);
-  auto quantiser = GradientQuantiser(gpair.DeviceSpan());
+  auto quantiser = GradientQuantiser(gpair.DeviceSpan(), MetaInfo());
   /**
    * Generate hist with cat data.
    */
diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
index 05098040024e..317728e01c0d 100644
--- a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
+++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
@@ -30,7 +30,7 @@ void TestUpdatePositionBatch() {
   std::vector<int> extra_data = {0};
   // Send the first five training instances to the right node
   // and the second 5 to the left node
-  rp.UpdatePositionBatch({0}, {1}, {2}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int) {
+  rp.UpdatePositionBatch({0}, {1}, {2}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int, int) {
     return ridx > 4;
   });
   rows = rp.GetRowsHost(1);
@@ -43,7 +43,7 @@ void TestUpdatePositionBatch() {
   }
 
   // Split the left node again
-  rp.UpdatePositionBatch({1}, {3}, {4}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int) {
+  rp.UpdatePositionBatch({1}, {3}, {4}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int, int) {
     return ridx < 7;
   });
   EXPECT_EQ(rp.GetRows(3).size(), 2);
@@ -57,7 +57,7 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Segment>& segments) {
   thrust::device_vector<uint32_t> ridx_tmp(ridx_in.size());
   thrust::device_vector<int> counts(segments.size());
 
-  auto op = [=] __device__(auto ridx, int data) { return ridx % 2 == 0; };
+  auto op = [=] __device__(auto ridx, int split_index, int data) { return ridx % 2 == 0; };
   std::vector<int> op_data(segments.size());
   std::vector<PerNodeData<int>> h_batch_info(segments.size());
   dh::TemporaryArray<PerNodeData<int>> d_batch_info(segments.size());
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 50cdae7413bd..76734e526e1b 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -93,7 +93,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   Context ctx{MakeCUDACtx(0)};
   auto cs = std::make_shared<common::ColumnSampler>(0);
   GPUHistMakerDevice maker(&ctx, /*is_external_memory=*/false, {}, kNRows, param, cs, kNCols,
-                           batch_param);
+                           batch_param, MetaInfo());
   xgboost::SimpleLCG gen;
   xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
   HostDeviceVector<GradientPair> gpair(kNRows);
@@ -111,7 +111,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   maker.hist.AllocateHistograms({0});
 
   maker.gpair = gpair.DeviceSpan();
-  maker.quantiser = std::make_unique<GradientQuantiser>(maker.gpair);
+  maker.quantiser = std::make_unique<GradientQuantiser>(maker.gpair, MetaInfo());
   maker.page = page.get();
 
   maker.InitFeatureGroupsOnce();
@@ -165,7 +165,7 @@ HistogramCutsWrapper GetHostCutMatrix () {
 inline GradientQuantiser DummyRoundingFactor() {
   thrust::device_vector<GradientPair> gpair(1);
   gpair[0] = {1000.f, 1000.f};  // Tests should not exceed sum of 1000
-  return GradientQuantiser(dh::ToSpan(gpair));
+  return {dh::ToSpan(gpair), MetaInfo()};
 }
 
 void TestHistogramIndexImpl() {
@@ -426,4 +426,54 @@ TEST(GpuHist, MaxDepth) {
   ASSERT_THROW({learner->UpdateOneIter(0, p_mat);}, dmlc::Error);
 }
 
+
+namespace {
+RegTree GetUpdatedTree(Context const* ctx, DMatrix* dmat) {
+  ObjInfo task{ObjInfo::kRegression};
+  GPUHistMaker hist_maker{ctx, &task};
+  hist_maker.Configure(Args{});
+
+  TrainParam param;
+  param.UpdateAllowUnknown(Args{});
+
+  linalg::Matrix<GradientPair> gpair({dmat->Info().num_row_}, ctx->Ordinal());
+  gpair.Data()->Copy(GenerateRandomGradients(dmat->Info().num_row_));
+
+  std::vector<HostDeviceVector<bst_node_t>> position(1);
+  RegTree tree;
+  hist_maker.Update(&param, &gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
+                    {&tree});
+  return tree;
+}
+
+void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& expected_tree) {
+  Context ctx(MakeCUDACtx(GPUIDX));
+
+  auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  std::unique_ptr<DMatrix> sliced{Xy->SliceCol(world_size, rank)};
+
+  RegTree tree = GetUpdatedTree(&ctx, sliced.get());
+
+  Json json{Object{}};
+  tree.SaveModel(&json);
+  Json expected_json{Object{}};
+  expected_tree.SaveModel(&expected_json);
+  ASSERT_EQ(json, expected_json);
+}
+}  // anonymous namespace
+
+class MGPUHistTest : public BaseMGPUTest {};
+
+TEST_F(MGPUHistTest, GPUHistColumnSplit) {
+  auto constexpr kRows = 32;
+  auto constexpr kCols = 16;
+
+  Context ctx(MakeCUDACtx(0));
+  auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
+  RegTree expected_tree = GetUpdatedTree(&ctx, dmat.get());
+
+  DoTest(VerifyColumnSplit, kRows, kCols, expected_tree);
+}
 }  // namespace xgboost::tree