Skip to content

Commit

Permalink
More support for column split in gpu predictor (#9562)
Browse files Browse the repository at this point in the history
  • Loading branch information
rongou authored Sep 14, 2023
1 parent a343ae3 commit d8c3cc9
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 124 deletions.
108 changes: 73 additions & 35 deletions src/predictor/gpu_predictor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -633,11 +633,12 @@ __global__ void MaskBitVectorKernel(
common::Span<std::uint32_t const> d_categories, BitVector decision_bits, BitVector missing_bits,
std::size_t tree_begin, std::size_t tree_end, std::size_t num_features, std::size_t num_rows,
std::size_t entry_start, std::size_t num_nodes, bool use_shared, float missing) {
// This needs to be always instantiated since the data is loaded cooperatively by all threads.
SparsePageLoader loader(data, use_shared, num_features, num_rows, entry_start, missing);
auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (row_idx >= num_rows) {
return;
}
SparsePageLoader loader(data, use_shared, num_features, num_rows, entry_start, missing);

std::size_t tree_offset = 0;
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
Expand Down Expand Up @@ -668,10 +669,10 @@ __global__ void MaskBitVectorKernel(
}
}

__device__ float GetLeafWeightByBitVector(bst_row_t ridx, TreeView const& tree,
BitVector const& decision_bits,
BitVector const& missing_bits, std::size_t num_nodes,
std::size_t tree_offset) {
__device__ bst_node_t GetLeafIndexByBitVector(bst_row_t ridx, TreeView const& tree,
BitVector const& decision_bits,
BitVector const& missing_bits, std::size_t num_nodes,
std::size_t tree_offset) {
bst_node_t nidx = 0;
RegTree::Node n = tree.d_tree[nidx];
while (!n.IsLeaf()) {
Expand All @@ -683,9 +684,19 @@ __device__ float GetLeafWeightByBitVector(bst_row_t ridx, TreeView const& tree,
}
n = tree.d_tree[nidx];
}
return nidx;
}

/**
 * Return the leaf value that row `ridx` ends up at in `tree`.
 *
 * The traversal itself is delegated to GetLeafIndexByBitVector; every argument is forwarded
 * unchanged. This wrapper only reads the value stored in the node at the returned index.
 */
__device__ float GetLeafWeightByBitVector(bst_row_t ridx, TreeView const& tree,
                                          BitVector const& decision_bits,
                                          BitVector const& missing_bits, std::size_t num_nodes,
                                          std::size_t tree_offset) {
  // Index of the terminal node, then its stored leaf value.
  return tree
      .d_tree[GetLeafIndexByBitVector(ridx, tree, decision_bits, missing_bits, num_nodes,
                                      tree_offset)]
      .LeafValue();
}

template <bool predict_leaf>
__global__ void PredictByBitVectorKernel(
common::Span<RegTree::Node const> d_nodes, common::Span<float> d_out_predictions,
common::Span<std::size_t const> d_tree_segments, common::Span<int const> d_tree_group,
Expand All @@ -701,27 +712,39 @@ __global__ void PredictByBitVectorKernel(
}

std::size_t tree_offset = 0;
if (num_group == 1) {
float sum = 0;
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
if constexpr (predict_leaf) {
for (size_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
TreeView d_tree{tree_begin, tree_idx, d_nodes,
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
d_cat_node_segments, d_categories};
sum += GetLeafWeightByBitVector(row_idx, d_tree, decision_bits, missing_bits, num_nodes,
tree_offset);
auto const leaf = GetLeafIndexByBitVector(row_idx, d_tree, decision_bits, missing_bits,
num_nodes, tree_offset);
d_out_predictions[row_idx * (tree_end - tree_begin) + tree_idx] = static_cast<float>(leaf);
tree_offset += d_tree.d_tree.size();
}
d_out_predictions[row_idx] += sum;
} else {
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
auto const tree_group = d_tree_group[tree_idx];
TreeView d_tree{tree_begin, tree_idx, d_nodes,
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
d_cat_node_segments, d_categories};
bst_uint out_prediction_idx = row_idx * num_group + tree_group;
d_out_predictions[out_prediction_idx] += GetLeafWeightByBitVector(
row_idx, d_tree, decision_bits, missing_bits, num_nodes, tree_offset);
tree_offset += d_tree.d_tree.size();
if (num_group == 1) {
float sum = 0;
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
TreeView d_tree{tree_begin, tree_idx, d_nodes,
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
d_cat_node_segments, d_categories};
sum += GetLeafWeightByBitVector(row_idx, d_tree, decision_bits, missing_bits, num_nodes,
tree_offset);
tree_offset += d_tree.d_tree.size();
}
d_out_predictions[row_idx] += sum;
} else {
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
auto const tree_group = d_tree_group[tree_idx];
TreeView d_tree{tree_begin, tree_idx, d_nodes,
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
d_cat_node_segments, d_categories};
bst_uint out_prediction_idx = row_idx * num_group + tree_group;
d_out_predictions[out_prediction_idx] += GetLeafWeightByBitVector(
row_idx, d_tree, decision_bits, missing_bits, num_nodes, tree_offset);
tree_offset += d_tree.d_tree.size();
}
}
}
}
Expand All @@ -733,13 +756,21 @@ class ColumnSplitHelper {
void PredictBatch(DMatrix* dmat, HostDeviceVector<float>* out_preds,
gbm::GBTreeModel const& model, DeviceModel const& d_model) const {
CHECK(dmat->PageExists<SparsePage>()) << "Column split for external memory is not support.";
PredictDMatrix(dmat, out_preds, d_model, model.learner_model_param->num_feature,
model.learner_model_param->num_output_group);
PredictDMatrix<false>(dmat, out_preds, d_model, model.learner_model_param->num_feature,
model.learner_model_param->num_output_group);
}

/**
 * Leaf prediction entry point for column-split data.
 *
 * Requires that `dmat` exposes SparsePage data (checked below), then runs PredictDMatrix with
 * the `predict_leaf` template flag set to true, using the feature count and output-group count
 * taken from the model's learner parameters.
 *
 * @param dmat      Input data; must have a SparsePage.
 * @param out_preds Output vector handed to PredictDMatrix.
 * @param model     Host-side model; only `learner_model_param` is read here.
 * @param d_model   Device-side model forwarded to PredictDMatrix.
 */
void PredictLeaf(DMatrix* dmat, HostDeviceVector<float>* out_preds, gbm::GBTreeModel const& model,
                 DeviceModel const& d_model) const {
  CHECK(dmat->PageExists<SparsePage>()) << "Column split for external memory is not support.";

  auto const& learner_param = *model.learner_model_param;
  PredictDMatrix<true>(dmat, out_preds, d_model, learner_param.num_feature,
                       learner_param.num_output_group);
}

private:
using BitType = BitVector::value_type;

template <bool predict_leaf>
void PredictDMatrix(DMatrix* dmat, HostDeviceVector<float>* out_preds, DeviceModel const& model,
bst_feature_t num_features, std::uint32_t num_group) const {
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
Expand Down Expand Up @@ -777,7 +808,7 @@ class ColumnSplitHelper {
AllReduceBitVectors(&decision_storage, &missing_storage);

dh::LaunchKernel {grid, kBlockThreads, 0, ctx_->CUDACtx()->Stream()} (
PredictByBitVectorKernel, model.nodes.ConstDeviceSpan(),
PredictByBitVectorKernel<predict_leaf>, model.nodes.ConstDeviceSpan(),
out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(),
model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(),
model.categories_tree_segments.ConstDeviceSpan(),
Expand All @@ -795,12 +826,11 @@ class ColumnSplitHelper {
ctx_->gpu_id, decision_storage->data().get(), decision_storage->size());
collective::AllReduce<collective::Operation::kBitwiseAND>(
ctx_->gpu_id, missing_storage->data().get(), missing_storage->size());
collective::Synchronize(ctx_->gpu_id);
}

void ResizeBitVectors(dh::caching_device_vector<BitType>* decision_storage,
dh::caching_device_vector<BitType>* missing_storage,
std::size_t total_bits) const {
dh::caching_device_vector<BitType>* missing_storage,
std::size_t total_bits) const {
auto const size = BitVector::ComputeStorageSize(total_bits);
if (decision_storage->size() < size) {
decision_storage->resize(size);
Expand Down Expand Up @@ -889,7 +919,7 @@ class GPUPredictor : public xgboost::Predictor {
DeviceModel d_model;
d_model.Init(model, tree_begin, tree_end, ctx_->gpu_id);

if (dmat->Info().IsColumnSplit()) {
if (info.IsColumnSplit()) {
column_split_helper_.PredictBatch(dmat, out_preds, model, d_model);
return;
}
Expand Down Expand Up @@ -1018,6 +1048,9 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented;
}
CHECK(!p_fmat->Info().IsColumnSplit())
<< "Predict contribution support for column-wise data split is not yet implemented.";

dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
out_contribs->SetDevice(ctx_->gpu_id);
if (tree_end == 0 || tree_end > model.trees.size()) {
Expand Down Expand Up @@ -1136,17 +1169,9 @@ class GPUPredictor : public xgboost::Predictor {
const gbm::GBTreeModel &model,
unsigned tree_end) const override {
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id);

const MetaInfo& info = p_fmat->Info();
constexpr uint32_t kBlockThreads = 128;
size_t shared_memory_bytes = SharedMemoryBytes<kBlockThreads>(
info.num_col_, max_shared_memory_bytes);
bool use_shared = shared_memory_bytes != 0;
bst_feature_t num_features = info.num_col_;
bst_row_t num_rows = info.num_row_;
size_t entry_start = 0;

if (tree_end == 0 || tree_end > model.trees.size()) {
tree_end = static_cast<uint32_t>(model.trees.size());
}
Expand All @@ -1155,6 +1180,19 @@ class GPUPredictor : public xgboost::Predictor {
DeviceModel d_model;
d_model.Init(model, 0, tree_end, this->ctx_->gpu_id);

if (info.IsColumnSplit()) {
column_split_helper_.PredictLeaf(p_fmat, predictions, model, d_model);
return;
}

auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id);
constexpr uint32_t kBlockThreads = 128;
size_t shared_memory_bytes = SharedMemoryBytes<kBlockThreads>(
info.num_col_, max_shared_memory_bytes);
bool use_shared = shared_memory_bytes != 0;
bst_feature_t num_features = info.num_col_;
size_t entry_start = 0;

if (p_fmat->PageExists<SparsePage>()) {
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(ctx_->gpu_id);
Expand Down
28 changes: 13 additions & 15 deletions tests/cpp/predictor/test_cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ TEST(CpuPredictor, IterationRange) {
}

TEST(CpuPredictor, IterationRangeColmnSplit) {
Context ctx;
TestIterationRangeColumnSplit(&ctx);
auto constexpr kWorldSize = 2;
TestIterationRangeColumnSplit(kWorldSize, false);
}

TEST(CpuPredictor, ExternalMemory) {
Expand Down Expand Up @@ -226,23 +226,21 @@ TEST(CPUPredictor, GHistIndexTraining) {
}

TEST(CPUPredictor, CategoricalPrediction) {
Context ctx;
TestCategoricalPrediction(&ctx, false);
TestCategoricalPrediction(false, false);
}

TEST(CPUPredictor, CategoricalPredictionColumnSplit) {
Context ctx;
TestCategoricalPredictionColumnSplit(&ctx);
auto constexpr kWorldSize = 2;
RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPrediction, false, true);
}

TEST(CPUPredictor, CategoricalPredictLeaf) {
Context ctx;
TestCategoricalPredictLeaf(&ctx, false);
TestCategoricalPredictLeaf(false, false);
}

TEST(CPUPredictor, CategoricalPredictLeafColumnSplit) {
Context ctx;
TestCategoricalPredictLeafColumnSplit(&ctx);
auto constexpr kWorldSize = 2;
RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPredictLeaf, false, true);
}

TEST(CpuPredictor, UpdatePredictionCache) {
Expand All @@ -256,8 +254,8 @@ TEST(CpuPredictor, LesserFeatures) {
}

TEST(CpuPredictor, LesserFeaturesColumnSplit) {
Context ctx;
TestPredictionWithLesserFeaturesColumnSplit(&ctx);
auto constexpr kWorldSize = 2;
RunWithInMemoryCommunicator(kWorldSize, TestPredictionWithLesserFeaturesColumnSplit, false);
}

TEST(CpuPredictor, Sparse) {
Expand All @@ -267,9 +265,9 @@ TEST(CpuPredictor, Sparse) {
}

TEST(CpuPredictor, SparseColumnSplit) {
Context ctx;
TestSparsePredictionColumnSplit(&ctx, 0.2);
TestSparsePredictionColumnSplit(&ctx, 0.8);
auto constexpr kWorldSize = 2;
TestSparsePredictionColumnSplit(kWorldSize, false, 0.2);
TestSparsePredictionColumnSplit(kWorldSize, false, 0.8);
}

TEST(CpuPredictor, Multi) {
Expand Down
27 changes: 23 additions & 4 deletions tests/cpp/predictor/test_gpu_predictor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ TEST(GpuPredictor, LesserFeatures) {
TestPredictionWithLesserFeatures(&ctx);
}

// Column-split variant of the LesserFeatures test: the shared helper
// TestPredictionWithLesserFeaturesColumnSplit is passed to RunWithInMemoryCommunicator together
// with the fixture's world_size_.
// NOTE(review): the trailing `true` presumably selects the GPU code path (the CPU test passes
// `false` in the same position) -- confirm against the helper's signature.
TEST_F(MGPUPredictorTest, LesserFeaturesColumnSplit) {
RunWithInMemoryCommunicator(world_size_, TestPredictionWithLesserFeaturesColumnSplit, true);
}

// Very basic test of empty model
TEST(GPUPredictor, ShapStump) {
cudaSetDevice(0);
Expand Down Expand Up @@ -270,14 +274,24 @@ TEST(GPUPredictor, IterationRange) {
TestIterationRange(&ctx);
}

// Column-split variant of the IterationRange test. Unlike the neighbouring column-split tests,
// the helper is called directly with the world size rather than through
// RunWithInMemoryCommunicator.
// NOTE(review): presumably the helper sets up the communicator itself, and `true` selects the
// GPU code path (the CPU test passes `false`) -- confirm against the helper.
TEST_F(MGPUPredictorTest, IterationRangeColumnSplit) {
TestIterationRangeColumnSplit(world_size_, true);
}

TEST(GPUPredictor, CategoricalPrediction) {
auto ctx = MakeCUDACtx(0);
TestCategoricalPrediction(&ctx, false);
TestCategoricalPrediction(true, false);
}

// Column-split variant of the CategoricalPrediction test: TestCategoricalPrediction is passed to
// RunWithInMemoryCommunicator together with the fixture's world_size_.
// NOTE(review): the two booleans are presumably (use_gpu, is_column_split); the single-GPU test
// passes (true, false) and the CPU column-split test passes (false, true) -- confirm.
TEST_F(MGPUPredictorTest, CategoricalPredictionColumnSplit) {
RunWithInMemoryCommunicator(world_size_, TestCategoricalPrediction, true, true);
}

TEST(GPUPredictor, CategoricalPredictLeaf) {
auto ctx = MakeCUDACtx(0);
TestCategoricalPredictLeaf(&ctx, false);
TestCategoricalPredictLeaf(true, false);
}

// Column-split variant of the CategoricalPredictLeaf test: TestCategoricalPredictLeaf is passed
// to RunWithInMemoryCommunicator together with the fixture's world_size_.
// NOTE(review): the two booleans are presumably (use_gpu, is_column_split); the single-GPU test
// passes (true, false) and the CPU column-split test passes (false, true) -- confirm.
TEST_F(MGPUPredictorTest, CategoricalPredictionLeafColumnSplit) {
RunWithInMemoryCommunicator(world_size_, TestCategoricalPredictLeaf, true, true);
}

TEST(GPUPredictor, PredictLeafBasic) {
Expand Down Expand Up @@ -305,4 +319,9 @@ TEST(GPUPredictor, Sparse) {
TestSparsePrediction(&ctx, 0.2);
TestSparsePrediction(&ctx, 0.8);
}

// Column-split variant of the Sparse test: the helper is run twice, with 0.2 and 0.8 as the
// last argument, mirroring the two calls in the single-GPU Sparse test.
// NOTE(review): the last argument is presumably the data sparsity and `true` presumably selects
// the GPU code path (the CPU test passes `false`) -- confirm against the helper.
TEST_F(MGPUPredictorTest, SparseColumnSplit) {
TestSparsePredictionColumnSplit(world_size_, true, 0.2);
TestSparsePredictionColumnSplit(world_size_, true, 0.8);
}
} // namespace xgboost::predictor
Loading

0 comments on commit d8c3cc9

Please sign in to comment.