From 818f3ada0bde526ec1dde6545cbf06868928b8e7 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 20 Jun 2023 11:20:06 -0700 Subject: [PATCH 1/9] minor optimization --- src/collective/nccl_device_communicator.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/collective/nccl_device_communicator.cu b/src/collective/nccl_device_communicator.cu index 6599d4b5a30e..dcec2c988738 100644 --- a/src/collective/nccl_device_communicator.cu +++ b/src/collective/nccl_device_communicator.cu @@ -129,10 +129,11 @@ template void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size, std::size_t size, cudaStream_t stream) { dh::LaunchN(size, stream, [=] __device__(std::size_t idx) { - out_buffer[idx] = device_buffer[idx]; + auto result = device_buffer[idx]; for (auto rank = 1; rank < world_size; rank++) { - out_buffer[idx] = func(out_buffer[idx], device_buffer[rank * size + idx]); + result = func(result, device_buffer[rank * size + idx]); } + out_buffer[idx] = result; }); } } // anonymous namespace From 647582ee1ecd47ae19c6ce404e73912cd53dd79f Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 20 Jun 2023 16:30:11 -0700 Subject: [PATCH 2/9] add basic test --- tests/cpp/predictor/test_gpu_predictor.cu | 66 +++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 4cf2970c1752..4285263852f1 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -57,6 +57,72 @@ TEST(GPUPredictor, Basic) { } } +namespace { +void VerifyBasicColumnSplit(std::array, 32> const& expected_result) { + auto const world_size = collective::GetWorldSize(); + auto const rank = collective::GetRank(); + + auto lparam = MakeCUDACtx(rank); + std::unique_ptr predictor = + std::unique_ptr(Predictor::Create("gpu_predictor", &lparam)); + predictor->Configure({}); + + for (size_t i = 1; i < 33; i *= 2) { + size_t n_row = i, n_col = i; + auto dmat = RandomDataGenerator(n_row, n_col, 0).GenerateDMatrix(); + std::unique_ptr sliced{dmat->SliceCol(world_size, rank)}; + + Context ctx; + ctx.gpu_id = rank; + LearnerModelParam mparam{MakeMP(n_col, .5, 1, ctx.gpu_id)}; + gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx); + + // Test predict batch + PredictionCacheEntry out_predictions; + + predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); + predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); + + std::vector& out_predictions_h = out_predictions.predictions.HostVector(); + EXPECT_EQ(out_predictions_h, expected_result[i - 1]); + } +} +} // anonymous namespace + +TEST(GPUPredictor, MGPUBasicColumnSplit) { + auto const n_gpus = common::AllVisibleGPUs(); + if (n_gpus <= 1) { + GTEST_SKIP() << "Skipping MGPUIBasicColumnSplit test with # GPUs = " << n_gpus; + } + + auto lparam = MakeCUDACtx(0); + std::unique_ptr predictor = + std::unique_ptr(Predictor::Create("gpu_predictor", &lparam)); + predictor->Configure({}); + + std::array, 32> result{}; + for (size_t i = 1; i < 33; i *= 2) { + size_t n_row = i, n_col = i; + auto dmat = RandomDataGenerator(n_row, n_col, 0).GenerateDMatrix(); + + Context ctx; + ctx.gpu_id = 0; + LearnerModelParam mparam{MakeMP(n_col, .5, 1, ctx.gpu_id)}; + gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx); + + // Test predict batch + PredictionCacheEntry out_predictions; + + predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); + predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); + + std::vector& out_predictions_h = out_predictions.predictions.HostVector(); + result[i - 1] = out_predictions_h; + } + + RunWithInMemoryCommunicator(n_gpus, VerifyBasicColumnSplit, result); +} + TEST(GPUPredictor, EllpackBasic) { size_t constexpr kCols {8}; for (size_t bins = 2; bins < 258; bins += 16) { From d6f062be03a39cfc402b95e1126e469abb1751b4 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 26 Jun 2023 13:19:54 -0700 Subject: [PATCH 3/9] add GetDecision function --- src/predictor/cpu_predictor.cc | 18 +++++++----------- src/predictor/predict_fn.h | 20 +++++++++++++------- tests/cpp/predictor/test_gpu_predictor.cu | 4 ++-- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 26b5a85b67d9..b4184f6fb5d0 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -467,7 +467,6 @@ class ColumnSplitHelper { void MaskOneTree(RegTree::FVec const &feat, std::size_t tree_id, std::size_t row_id) { auto const &tree = *model_.trees[tree_id]; auto const &cats = tree.GetCategoriesMatrix(); - auto const has_categorical = tree.HasCategoricalSplit(); bst_node_t n_nodes = tree.GetNodes().size(); for (bst_node_t nid = 0; nid < n_nodes; nid++) { @@ -484,16 +483,13 @@ class ColumnSplitHelper { } auto const fvalue = feat.GetFvalue(split_index); - if (has_categorical && common::IsCat(cats.split_type, nid)) { - auto const node_categories = - cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size); - if (!common::Decision(node_categories, fvalue)) { - decision_bits_.Set(bit_index); - } - continue; + auto decision = false; + if (tree.HasCategoricalSplit()) { + decision = GetDecision(node, nid, fvalue, cats); + } else { + decision = GetDecision(node, nid, fvalue, cats); } - - if (fvalue >= node.SplitCond()) { + if (decision) { decision_bits_.Set(bit_index); } } @@ -511,7 +507,7 @@ class ColumnSplitHelper { if (missing_bits_.Check(bit_index)) { return node.DefaultChild(); } else { - return node.LeftChild() + decision_bits_.Check(bit_index); + return node.LeftChild() + !decision_bits_.Check(bit_index); } } diff --git a/src/predictor/predict_fn.h b/src/predictor/predict_fn.h index dbaf4a75e060..a8b4c45c595f 100644 --- a/src/predictor/predict_fn.h +++ b/src/predictor/predict_fn.h @@ -7,6 +7,18 @@ #include "xgboost/tree_model.h" namespace xgboost::predictor { +/** @brief Whether it should traverse to the left branch of a tree. */ +template +inline XGBOOST_DEVICE bool GetDecision(RegTree::Node const &node, bst_node_t nid, float fvalue, + RegTree::CategoricalSplitMatrix const &cats) { + if (has_categorical && common::IsCat(cats.split_type, nid)) { + auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size); + return common::Decision(node_categories, fvalue); + } else { + return fvalue < node.SplitCond(); + } +} + template inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue, bool is_missing, @@ -14,13 +26,7 @@ inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bs if (has_missing && is_missing) { return node.DefaultChild(); } else { - if (has_categorical && common::IsCat(cats.split_type, nid)) { - auto node_categories = - cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size); - return common::Decision(node_categories, fvalue) ? node.LeftChild() : node.RightChild(); - } else { - return node.LeftChild() + !(fvalue < node.SplitCond()); - } + return node.LeftChild() + !GetDecision(node, nid, fvalue, cats); } } diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 4285263852f1..14e55c1578ad 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -80,8 +80,8 @@ void VerifyBasicColumnSplit(std::array, 32> const& expected_r // Test predict batch PredictionCacheEntry out_predictions; - predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); - predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); + predictor->InitOutPredictions(sliced->Info(), &out_predictions.predictions, model); + predictor->PredictBatch(sliced.get(), &out_predictions, model, 0); std::vector& out_predictions_h = out_predictions.predictions.HostVector(); EXPECT_EQ(out_predictions_h, expected_result[i - 1]); From 9dc920a453e4af683f8a612bfe4e3a8a2a32e8ec Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 27 Jun 2023 11:59:31 -0700 Subject: [PATCH 4/9] support colsplit in gpu predictor --- src/predictor/cpu_predictor.cc | 9 +- src/predictor/gpu_predictor.cu | 207 ++++++++++++++++++++++++++++++++- 2 files changed, 206 insertions(+), 10 deletions(-) diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index b4184f6fb5d0..56362a11251c 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -483,12 +483,9 @@ class ColumnSplitHelper { } auto const fvalue = feat.GetFvalue(split_index); - auto decision = false; - if (tree.HasCategoricalSplit()) { - decision = GetDecision(node, nid, fvalue, cats); - } else { - decision = GetDecision(node, nid, fvalue, cats); - } + auto const decision = tree.HasCategoricalSplit() + ? GetDecision(node, nid, fvalue, cats) + : GetDecision(node, nid, fvalue, cats); if (decision) { decision_bits_.Set(bit_index); } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 98e38068239f..4438391315e8 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -11,6 +11,7 @@ #include // for any, any_cast #include +#include "../collective/communicator-inl.cuh" #include "../common/bitfield.h" #include "../common/categorical.h" #include "../common/common.h" @@ -110,13 +111,11 @@ struct SparsePageLoader { bool use_shared; SparsePageView data; float* smem; - size_t entry_start; __device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features, bst_row_t num_rows, size_t entry_start, float) : use_shared(use_shared), - data(data), - entry_start(entry_start) { + data(data) { extern __shared__ float _smem[]; smem = _smem; // Copy instances @@ -622,6 +621,198 @@ size_t SharedMemoryBytes(size_t cols, size_t max_shared_memory_bytes) { } return shared_memory_bytes; } + +using BitVector = LBitField64; + +__global__ void MaskBitVectorKernel( + SparsePageView data, common::Span d_nodes, + common::Span d_tree_segments, common::Span d_tree_group, + common::Span d_tree_split_types, + common::Span d_cat_tree_segments, + common::Span d_cat_node_segments, + common::Span d_categories, BitVector decision_bits, BitVector missing_bits, + std::size_t tree_begin, std::size_t tree_end, std::size_t num_features, std::size_t num_rows, + std::size_t entry_start, std::size_t num_nodes, bool use_shared, float missing) { + auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (row_idx >= num_rows) { + return; + } + SparsePageLoader loader(data, use_shared, num_features, num_rows, entry_start, missing); + + std::size_t tree_offset = 0; + for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + TreeView d_tree{tree_begin, tree_idx, d_nodes, + d_tree_segments, d_tree_split_types, d_cat_tree_segments, + d_cat_node_segments, d_categories}; + auto const tree_nodes = d_tree.d_tree.size(); + for (auto nid = 0; nid < tree_nodes; nid++) { + auto const& node = d_tree.d_tree[nid]; + if (node.IsDeleted() || node.IsLeaf()) { + continue; + } + auto const fvalue = loader.GetElement(row_idx, node.SplitIndex()); + auto const is_missing = common::CheckNAN(fvalue); + auto const bit_index = row_idx * num_nodes + tree_offset + nid; + if (is_missing) { + missing_bits.Set(bit_index); + } else { + auto const decision = d_tree.HasCategoricalSplit() + ? GetDecision(node, nid, fvalue, d_tree.cats) + : GetDecision(node, nid, fvalue, d_tree.cats); + if (decision) { + decision_bits.Set(bit_index); + } + } + } + tree_offset += tree_nodes; + } +} + +__device__ float GetLeafWeightByBitVector(bst_row_t ridx, TreeView const& tree, + BitVector const& decision_bits, + BitVector const& missing_bits, std::size_t num_nodes, + std::size_t tree_offset) { + bst_node_t nidx = 0; + RegTree::Node n = tree.d_tree[nidx]; + while (!n.IsLeaf()) { + auto const bit_index = ridx * num_nodes + tree_offset + nidx; + if (missing_bits.Check(bit_index)) { + nidx = n.DefaultChild(); + } else { + nidx = n.LeftChild() + !decision_bits.Check(bit_index); + } + n = tree.d_tree[nidx]; + } + return tree.d_tree[nidx].LeafValue(); +} + +__global__ void PredictByBitVectorKernel( + common::Span d_nodes, common::Span d_out_predictions, + common::Span d_tree_segments, common::Span d_tree_group, + common::Span d_tree_split_types, + common::Span d_cat_tree_segments, + common::Span d_cat_node_segments, + common::Span d_categories, BitVector decision_bits, BitVector missing_bits, + std::size_t tree_begin, std::size_t tree_end, std::size_t num_rows, std::size_t num_nodes, + std::uint32_t num_group) { + auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (row_idx >= num_rows) { + return; + } + + std::size_t tree_offset = 0; + if (num_group == 1) { + float sum = 0; + for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + TreeView d_tree{tree_begin, tree_idx, d_nodes, + d_tree_segments, d_tree_split_types, d_cat_tree_segments, + d_cat_node_segments, d_categories}; + sum += GetLeafWeightByBitVector(row_idx, d_tree, decision_bits, missing_bits, num_nodes, + tree_offset); + tree_offset += d_tree.d_tree.size(); + } + d_out_predictions[row_idx] += sum; + } else { + for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + auto const tree_group = d_tree_group[tree_idx]; + TreeView d_tree{tree_begin, tree_idx, d_nodes, + d_tree_segments, d_tree_split_types, d_cat_tree_segments, + d_cat_node_segments, d_categories}; + bst_uint out_prediction_idx = row_idx * num_group + tree_group; + d_out_predictions[out_prediction_idx] += GetLeafWeightByBitVector( + row_idx, d_tree, decision_bits, missing_bits, num_nodes, tree_offset); + tree_offset += d_tree.d_tree.size(); + } + } +} + +class ColumnSplitHelper { + public: + explicit ColumnSplitHelper(std::int32_t gpu_id) : gpu_id_{gpu_id} {} + + void PredictBatch(DMatrix* dmat, HostDeviceVector* out_preds, + gbm::GBTreeModel const& model, DeviceModel const& d_model) const { + CHECK(dmat->PageExists()) << "Column split for external memory is not support."; + PredictDMatrix(dmat, out_preds, d_model, model.learner_model_param->num_feature, + model.learner_model_param->num_output_group); + } + + private: + using BitType = BitVector::value_type; + + void PredictDMatrix(DMatrix* dmat, HostDeviceVector* out_preds, DeviceModel const& model, + bst_feature_t num_features, std::uint32_t num_group) const { + dh::safe_cuda(cudaSetDevice(gpu_id_)); + dh::caching_device_vector decision_storage{}; + dh::caching_device_vector missing_storage{}; + + auto constexpr kBlockThreads = 128; + auto const max_shared_memory_bytes = dh::MaxSharedMemory(gpu_id_); + auto const shared_memory_bytes = + SharedMemoryBytes(num_features, max_shared_memory_bytes); + auto const use_shared = shared_memory_bytes != 0; + + auto const num_nodes = model.nodes.Size(); + std::size_t batch_offset = 0; + for (auto const& batch : dmat->GetBatches()) { + auto const num_rows = batch.Size(); + ResizeBitVectors(&decision_storage, &missing_storage, num_rows * num_nodes); + BitVector decision_bits{dh::ToSpan(decision_storage)}; + BitVector missing_bits{dh::ToSpan(missing_storage)}; + + batch.offset.SetDevice(gpu_id_); + batch.data.SetDevice(gpu_id_); + std::size_t entry_start = 0; + SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features); + + auto const grid = static_cast(common::DivRoundUp(num_rows, kBlockThreads)); + dh::LaunchKernel{grid, kBlockThreads, shared_memory_bytes}( + MaskBitVectorKernel, data, model.nodes.ConstDeviceSpan(), + model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(), + model.split_types.ConstDeviceSpan(), model.categories_tree_segments.ConstDeviceSpan(), + model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(), + decision_bits, missing_bits, model.tree_beg_, model.tree_end_, num_features, num_rows, + entry_start, num_nodes, use_shared, nan("")); + + AllReduceBitVectors(&decision_storage, &missing_storage); + + dh::LaunchKernel{grid, kBlockThreads}( + PredictByBitVectorKernel, model.nodes.ConstDeviceSpan(), + out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(), + model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(), + model.categories_tree_segments.ConstDeviceSpan(), + model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(), + decision_bits, missing_bits, model.tree_beg_, model.tree_end_, num_rows, num_nodes, + num_group); + + batch_offset += batch.Size() * num_group; + } + } + + void AllReduceBitVectors(dh::caching_device_vector* decision_storage, + dh::caching_device_vector* missing_storage) const { + collective::AllReduce( + gpu_id_, decision_storage->data().get(), decision_storage->size()); + collective::AllReduce( // Align to make it easier to read. + gpu_id_, missing_storage->data().get(), missing_storage->size()); + } + + static void ResizeBitVectors(dh::caching_device_vector* decision_storage, + dh::caching_device_vector* missing_storage, + std::size_t total_bits) { + auto const size = BitVector::ComputeStorageSize(total_bits); + if (decision_storage->size() < size) { + decision_storage->resize(size); + } + decision_storage->clear(); + if (missing_storage->size() < size) { + missing_storage->resize(size); + } + missing_storage->clear(); + } + + std::int32_t const gpu_id_; +}; } // anonymous namespace class GPUPredictor : public xgboost::Predictor { @@ -697,6 +888,11 @@ class GPUPredictor : public xgboost::Predictor { DeviceModel d_model; d_model.Init(model, tree_begin, tree_end, ctx_->gpu_id); + if (dmat->Info().IsColumnSplit()) { + column_split_helper_.PredictBatch(dmat, out_preds, model, d_model); + return; + } + if (dmat->PageExists()) { size_t batch_offset = 0; for (auto &batch : dmat->GetBatches()) { @@ -720,7 +916,8 @@ class GPUPredictor : public xgboost::Predictor { } public: - explicit GPUPredictor(Context const* ctx) : Predictor::Predictor{ctx} {} + explicit GPUPredictor(Context const* ctx) + : Predictor::Predictor{ctx}, column_split_helper_{ctx->gpu_id} {} ~GPUPredictor() override { if (ctx_->gpu_id >= 0 && ctx_->gpu_id < common::AllVisibleGPUs()) { @@ -1019,6 +1216,8 @@ class GPUPredictor : public xgboost::Predictor { } return 0; } + + ColumnSplitHelper column_split_helper_; }; XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor") From 3ea4835633bdec0a2d1fee7f8b3894bd55d53be6 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 27 Jun 2023 12:36:06 -0700 Subject: [PATCH 5/9] fix lint error --- src/predictor/gpu_predictor.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 4438391315e8..8b2cf8674aed 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -766,7 +766,7 @@ class ColumnSplitHelper { SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features); auto const grid = static_cast(common::DivRoundUp(num_rows, kBlockThreads)); - dh::LaunchKernel{grid, kBlockThreads, shared_memory_bytes}( + dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} ( MaskBitVectorKernel, data, model.nodes.ConstDeviceSpan(), model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(), model.categories_tree_segments.ConstDeviceSpan(), @@ -776,7 +776,7 @@ class ColumnSplitHelper { AllReduceBitVectors(&decision_storage, &missing_storage); - dh::LaunchKernel{grid, kBlockThreads}( + dh::LaunchKernel {grid, kBlockThreads} ( PredictByBitVectorKernel, model.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(), @@ -795,6 +795,7 @@ class ColumnSplitHelper { gpu_id_, decision_storage->data().get(), decision_storage->size()); collective::AllReduce( // Align to make it easier to read. gpu_id_, missing_storage->data().get(), missing_storage->size()); + collective::Synchronize(gpu_id_); } static void ResizeBitVectors(dh::caching_device_vector* decision_storage, From 97a895a11707372f3f9d85326fc1a654202955ff Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 28 Jun 2023 15:37:05 -0700 Subject: [PATCH 6/9] address review comments --- src/predictor/gpu_predictor.cu | 20 ++++++++++---------- src/predictor/predict_fn.h | 4 ++-- tests/cpp/predictor/test_gpu_predictor.cu | 12 ++++-------- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 8b2cf8674aed..9367b14ca94c 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -728,7 +728,7 @@ __global__ void PredictByBitVectorKernel( class ColumnSplitHelper { public: - explicit ColumnSplitHelper(std::int32_t gpu_id) : gpu_id_{gpu_id} {} + explicit ColumnSplitHelper(Context const* ctx) : ctx_{ctx} {} void PredictBatch(DMatrix* dmat, HostDeviceVector* out_preds, gbm::GBTreeModel const& model, DeviceModel const& d_model) const { @@ -742,12 +742,12 @@ class ColumnSplitHelper { void PredictDMatrix(DMatrix* dmat, HostDeviceVector* out_preds, DeviceModel const& model, bst_feature_t num_features, std::uint32_t num_group) const { - dh::safe_cuda(cudaSetDevice(gpu_id_)); + dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); dh::caching_device_vector decision_storage{}; dh::caching_device_vector missing_storage{}; auto constexpr kBlockThreads = 128; - auto const max_shared_memory_bytes = dh::MaxSharedMemory(gpu_id_); + auto const max_shared_memory_bytes = dh::MaxSharedMemory(ctx_->gpu_id); auto const shared_memory_bytes = SharedMemoryBytes(num_features, max_shared_memory_bytes); auto const use_shared = shared_memory_bytes != 0; @@ -760,8 +760,8 @@ class ColumnSplitHelper { BitVector decision_bits{dh::ToSpan(decision_storage)}; BitVector missing_bits{dh::ToSpan(missing_storage)}; - batch.offset.SetDevice(gpu_id_); - batch.data.SetDevice(gpu_id_); + batch.offset.SetDevice(ctx_->gpu_id); + batch.data.SetDevice(ctx_->gpu_id); std::size_t entry_start = 0; SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features); @@ -792,10 +792,10 @@ class ColumnSplitHelper { void AllReduceBitVectors(dh::caching_device_vector* decision_storage, dh::caching_device_vector* missing_storage) const { collective::AllReduce( - gpu_id_, decision_storage->data().get(), decision_storage->size()); + ctx_->gpu_id, decision_storage->data().get(), decision_storage->size()); collective::AllReduce( // Align to make it easier to read. - gpu_id_, missing_storage->data().get(), missing_storage->size()); - collective::Synchronize(gpu_id_); + ctx_->gpu_id, missing_storage->data().get(), missing_storage->size()); + collective::Synchronize(ctx_->gpu_id); } static void ResizeBitVectors(dh::caching_device_vector* decision_storage, @@ -812,7 +812,7 @@ class ColumnSplitHelper { missing_storage->clear(); } - std::int32_t const gpu_id_; + Context const* ctx_; }; } // anonymous namespace @@ -918,7 +918,7 @@ class GPUPredictor : public xgboost::Predictor { public: explicit GPUPredictor(Context const* ctx) - : Predictor::Predictor{ctx}, column_split_helper_{ctx->gpu_id} {} + : Predictor::Predictor{ctx}, column_split_helper_{ctx} {} ~GPUPredictor() override { if (ctx_->gpu_id >= 0 && ctx_->gpu_id < common::AllVisibleGPUs()) { diff --git a/src/predictor/predict_fn.h b/src/predictor/predict_fn.h index a8b4c45c595f..044832010ccb 100644 --- a/src/predictor/predict_fn.h +++ b/src/predictor/predict_fn.h @@ -9,8 +9,8 @@ namespace xgboost::predictor { /** @brief Whether it should traverse to the left branch of a tree. */ template -inline XGBOOST_DEVICE bool GetDecision(RegTree::Node const &node, bst_node_t nid, float fvalue, - RegTree::CategoricalSplitMatrix const &cats) { +XGBOOST_DEVICE bool GetDecision(RegTree::Node const &node, bst_node_t nid, float fvalue, + RegTree::CategoricalSplitMatrix const &cats) { if (has_categorical && common::IsCat(cats.split_type, nid)) { auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size); return common::Decision(node_categories, fvalue); diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 14e55c1578ad..6911824a98e6 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -62,9 +62,9 @@ void VerifyBasicColumnSplit(std::array, 32> const& expected_r auto const world_size = collective::GetWorldSize(); auto const rank = collective::GetRank(); - auto lparam = MakeCUDACtx(rank); + auto ctx = MakeCUDACtx(rank); std::unique_ptr predictor = - std::unique_ptr(Predictor::Create("gpu_predictor", &lparam)); + std::unique_ptr(Predictor::Create("gpu_predictor", &ctx)); predictor->Configure({}); for (size_t i = 1; i < 33; i *= 2) { @@ -72,8 +72,6 @@ void VerifyBasicColumnSplit(std::array, 32> const& expected_r auto dmat = RandomDataGenerator(n_row, n_col, 0).GenerateDMatrix(); std::unique_ptr sliced{dmat->SliceCol(world_size, rank)}; - Context ctx; - ctx.gpu_id = rank; LearnerModelParam mparam{MakeMP(n_col, .5, 1, ctx.gpu_id)}; gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx); @@ -95,9 +93,9 @@ TEST(GPUPredictor, MGPUBasicColumnSplit) { GTEST_SKIP() << "Skipping MGPUIBasicColumnSplit test with # GPUs = " << n_gpus; } - auto lparam = MakeCUDACtx(0); + auto ctx = MakeCUDACtx(0); std::unique_ptr predictor = - std::unique_ptr(Predictor::Create("gpu_predictor", &lparam)); + std::unique_ptr(Predictor::Create("gpu_predictor", &ctx)); predictor->Configure({}); std::array, 32> result{}; @@ -105,8 +103,6 @@ TEST(GPUPredictor, MGPUBasicColumnSplit) { size_t n_row = i, n_col = i; auto dmat = RandomDataGenerator(n_row, n_col, 0).GenerateDMatrix(); - Context ctx; - ctx.gpu_id = 0; LearnerModelParam mparam{MakeMP(n_col, .5, 1, ctx.gpu_id)}; gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx); From 96890ec391805682a290ef90b38054209d3e0dc5 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 29 Jun 2023 11:17:37 -0700 Subject: [PATCH 7/9] more review feedback --- src/predictor/gpu_predictor.cu | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 9367b14ca94c..1f7f69c3a999 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -15,6 +15,7 @@ #include "../common/bitfield.h" #include "../common/categorical.h" #include "../common/common.h" +#include "../common/cuda_context.cuh" #include "../common/device_helpers.cuh" #include "../data/device_adapter.cuh" #include "../data/ellpack_page.cuh" @@ -766,7 +767,7 @@ class ColumnSplitHelper { SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features); auto const grid = static_cast(common::DivRoundUp(num_rows, kBlockThreads)); - dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} ( + dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes, ctx_->CUDACtx()->Stream()} ( MaskBitVectorKernel, data, model.nodes.ConstDeviceSpan(), model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(), model.categories_tree_segments.ConstDeviceSpan(), @@ -776,7 +777,7 @@ class ColumnSplitHelper { AllReduceBitVectors(&decision_storage, &missing_storage); - dh::LaunchKernel {grid, kBlockThreads} ( + dh::LaunchKernel {grid, kBlockThreads, 0, ctx_->CUDACtx()->Stream()} ( PredictByBitVectorKernel, model.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(), @@ -791,9 +792,9 @@ class ColumnSplitHelper { void AllReduceBitVectors(dh::caching_device_vector* decision_storage, dh::caching_device_vector* missing_storage) const { - collective::AllReduce( + collective::AllReduce( ctx_->gpu_id, decision_storage->data().get(), decision_storage->size()); - collective::AllReduce( // Align to make it easier to read. + collective::AllReduce( ctx_->gpu_id, missing_storage->data().get(), missing_storage->size()); collective::Synchronize(ctx_->gpu_id); } @@ -802,14 +803,20 @@ class ColumnSplitHelper { dh::caching_device_vector* missing_storage, std::size_t total_bits) { auto const size = BitVector::ComputeStorageSize(total_bits); - if (decision_storage->size() < size) { + auto const old_decision_size = decision_storage->size(); + if (old_decision_size < size) { decision_storage->resize(size); } - decision_storage->clear(); - if (missing_storage->size() < size) { + if (old_decision_size != 0) { + thrust::fill(decision_storage->begin(), decision_storage->end(), 0); + } + auto const old_missing_size = missing_storage->size(); + if (old_missing_size < size) { missing_storage->resize(size); } - missing_storage->clear(); + if (old_missing_size != 0) { + thrust::fill(missing_storage->begin(), missing_storage->end(), 0); + } } Context const* ctx_; From 3186816c2b275ed4bee026ea31ee1fb6fdc27075 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 29 Jun 2023 13:32:13 -0700 Subject: [PATCH 8/9] review feedback --- src/predictor/gpu_predictor.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 1f7f69c3a999..2eee14e30678 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -799,23 +799,23 @@ class ColumnSplitHelper { collective::Synchronize(ctx_->gpu_id); } - static void ResizeBitVectors(dh::caching_device_vector* decision_storage, + void ResizeBitVectors(dh::caching_device_vector* decision_storage, dh::caching_device_vector* missing_storage, - std::size_t total_bits) { + std::size_t total_bits) const { auto const size = BitVector::ComputeStorageSize(total_bits); auto const old_decision_size = decision_storage->size(); if (old_decision_size < size) { decision_storage->resize(size); } if (old_decision_size != 0) { - thrust::fill(decision_storage->begin(), decision_storage->end(), 0); + thrust::fill(ctx_->CUDACtx()->CTP(), decision_storage->begin(), decision_storage->end(), 0); } auto const old_missing_size = missing_storage->size(); if (old_missing_size < size) { missing_storage->resize(size); } if (old_missing_size != 0) { - thrust::fill(missing_storage->begin(), missing_storage->end(), 0); + thrust::fill(ctx_->CUDACtx()->CTP(), missing_storage->begin(), missing_storage->end(), 0); } } From 6dbc9ac7347fad3979b7215a07282400b7b6d9b5 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Sat, 1 Jul 2023 15:05:46 -0700 Subject: [PATCH 9/9] simplify resizing bit vectors --- src/predictor/gpu_predictor.cu | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 2eee14e30678..4ca0e33fff55 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -803,20 +803,14 @@ class ColumnSplitHelper { dh::caching_device_vector* missing_storage, std::size_t total_bits) const { auto const size = BitVector::ComputeStorageSize(total_bits); - auto const old_decision_size = decision_storage->size(); - if (old_decision_size < size) { + if (decision_storage->size() < size) { decision_storage->resize(size); } - if (old_decision_size != 0) { - thrust::fill(ctx_->CUDACtx()->CTP(), decision_storage->begin(), decision_storage->end(), 0); - } - auto const old_missing_size = missing_storage->size(); - if (old_missing_size < size) { + thrust::fill(ctx_->CUDACtx()->CTP(), decision_storage->begin(), decision_storage->end(), 0); + if (missing_storage->size() < size) { missing_storage->resize(size); } - if (old_missing_size != 0) { - thrust::fill(ctx_->CUDACtx()->CTP(), missing_storage->begin(), missing_storage->end(), 0); - } + thrust::fill(ctx_->CUDACtx()->CTP(), missing_storage->begin(), missing_storage->end(), 0); } Context const* ctx_;