Support column split in GPU predictor #9343

Merged · 12 commits · Jul 2, 2023
5 changes: 3 additions & 2 deletions src/collective/nccl_device_communicator.cu
@@ -129,10 +129,11 @@ template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
std::size_t size, cudaStream_t stream) {
dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
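// Accumulate the reduction in a register and write out_buffer once at the
// end, instead of a read-modify-write on global memory for every rank.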
- out_buffer[idx] = device_buffer[idx];
+ auto result = device_buffer[idx];
for (auto rank = 1; rank < world_size; rank++) {
- out_buffer[idx] = func(out_buffer[idx], device_buffer[rank * size + idx]);
+ result = func(result, device_buffer[rank * size + idx]);
}
+ out_buffer[idx] = result;
});
}
} // anonymous namespace
17 changes: 5 additions & 12 deletions src/predictor/cpu_predictor.cc
@@ -467,7 +467,6 @@ class ColumnSplitHelper {
void MaskOneTree(RegTree::FVec const &feat, std::size_t tree_id, std::size_t row_id) {
auto const &tree = *model_.trees[tree_id];
auto const &cats = tree.GetCategoriesMatrix();
- auto const has_categorical = tree.HasCategoricalSplit();
bst_node_t n_nodes = tree.GetNodes().size();

for (bst_node_t nid = 0; nid < n_nodes; nid++) {
@@ -484,16 +483,10 @@ }
}

auto const fvalue = feat.GetFvalue(split_index);
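// decision_bits_ now records "take the left branch" (the result of
// GetDecision) rather than "take the right branch", so the traversal below
// negates the bit when picking a child.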
- if (has_categorical && common::IsCat(cats.split_type, nid)) {
-   auto const node_categories =
-       cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
-   if (!common::Decision(node_categories, fvalue)) {
-     decision_bits_.Set(bit_index);
-   }
-   continue;
- }
-
- if (fvalue >= node.SplitCond()) {
+ auto const decision = tree.HasCategoricalSplit()
+                           ? GetDecision<true>(node, nid, fvalue, cats)
+                           : GetDecision<false>(node, nid, fvalue, cats);
+ if (decision) {
decision_bits_.Set(bit_index);
}
}
@@ -511,7 +504,7 @@ class ColumnSplitHelper {
if (missing_bits_.Check(bit_index)) {
return node.DefaultChild();
} else {
- return node.LeftChild() + decision_bits_.Check(bit_index);
+ return node.LeftChild() + !decision_bits_.Check(bit_index);
}
}

Expand Down
208 changes: 204 additions & 4 deletions src/predictor/gpu_predictor.cu
@@ -11,6 +11,7 @@
#include <any> // for any, any_cast
#include <memory>

#include "../collective/communicator-inl.cuh"
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/common.h"
@@ -110,13 +111,11 @@ struct SparsePageLoader {
bool use_shared;
SparsePageView data;
float* smem;
- size_t entry_start;

__device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
bst_row_t num_rows, size_t entry_start, float)
: use_shared(use_shared),
-   data(data),
-   entry_start(entry_start) {
+   data(data) {
extern __shared__ float _smem[];
smem = _smem;
// Copy instances
@@ -622,6 +621,199 @@ size_t SharedMemoryBytes(size_t cols, size_t max_shared_memory_bytes) {
}
return shared_memory_bytes;
}

using BitVector = LBitField64;
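// One bit per (row, node) pair. LBitField64 is backed by 64-bit words, the
// unit in which the bit vectors are allreduced across workers below.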

__global__ void MaskBitVectorKernel(
SparsePageView data, common::Span<RegTree::Node const> d_nodes,
common::Span<std::size_t const> d_tree_segments, common::Span<int const> d_tree_group,
common::Span<FeatureType const> d_tree_split_types,
common::Span<std::uint32_t const> d_cat_tree_segments,
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
common::Span<std::uint32_t const> d_categories, BitVector decision_bits, BitVector missing_bits,
std::size_t tree_begin, std::size_t tree_end, std::size_t num_features, std::size_t num_rows,
std::size_t entry_start, std::size_t num_nodes, bool use_shared, float missing) {
auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (row_idx >= num_rows) {
return;
}
SparsePageLoader loader(data, use_shared, num_features, num_rows, entry_start, missing);

std::size_t tree_offset = 0;
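// Bit layout is row-major: bit_index = row_idx * num_nodes + tree_offset + nid,
// where tree_offset counts the nodes of all preceding trees.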
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
TreeView d_tree{tree_begin, tree_idx, d_nodes,
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
d_cat_node_segments, d_categories};
auto const tree_nodes = d_tree.d_tree.size();
for (auto nid = 0; nid < tree_nodes; nid++) {
auto const& node = d_tree.d_tree[nid];
if (node.IsDeleted() || node.IsLeaf()) {
continue;
}
auto const fvalue = loader.GetElement(row_idx, node.SplitIndex());
auto const is_missing = common::CheckNAN(fvalue);
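// Under column split each worker holds only a subset of the features, so a
// worker that lacks this split feature sees a missing value here; only the
// feature's owner contributes the real decision below.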
auto const bit_index = row_idx * num_nodes + tree_offset + nid;
if (is_missing) {
missing_bits.Set(bit_index);
} else {
auto const decision = d_tree.HasCategoricalSplit()
? GetDecision<true>(node, nid, fvalue, d_tree.cats)
: GetDecision<false>(node, nid, fvalue, d_tree.cats);
if (decision) {
decision_bits.Set(bit_index);
}
}
}
tree_offset += tree_nodes;
}
}

__device__ float GetLeafWeightByBitVector(bst_row_t ridx, TreeView const& tree,
BitVector const& decision_bits,
BitVector const& missing_bits, std::size_t num_nodes,
std::size_t tree_offset) {
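// Walk the tree using the aggregated bit vectors; after the allreduce, every
// worker observes identical decision and missing bits for each node.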
bst_node_t nidx = 0;
RegTree::Node n = tree.d_tree[nidx];
while (!n.IsLeaf()) {
auto const bit_index = ridx * num_nodes + tree_offset + nidx;
if (missing_bits.Check(bit_index)) {
nidx = n.DefaultChild();
} else {
nidx = n.LeftChild() + !decision_bits.Check(bit_index);
}
n = tree.d_tree[nidx];
}
return tree.d_tree[nidx].LeafValue();
}

__global__ void PredictByBitVectorKernel(
common::Span<RegTree::Node const> d_nodes, common::Span<float> d_out_predictions,
common::Span<std::size_t const> d_tree_segments, common::Span<int const> d_tree_group,
common::Span<FeatureType const> d_tree_split_types,
common::Span<std::uint32_t const> d_cat_tree_segments,
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
common::Span<std::uint32_t const> d_categories, BitVector decision_bits, BitVector missing_bits,
std::size_t tree_begin, std::size_t tree_end, std::size_t num_rows, std::size_t num_nodes,
std::uint32_t num_group) {
auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (row_idx >= num_rows) {
return;
}

std::size_t tree_offset = 0;
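// With a single output group, sum the tree contributions in a register and
// update global memory once; otherwise each tree writes to the prediction
// slot of its own group.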
if (num_group == 1) {
float sum = 0;
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
TreeView d_tree{tree_begin, tree_idx, d_nodes,
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
d_cat_node_segments, d_categories};
sum += GetLeafWeightByBitVector(row_idx, d_tree, decision_bits, missing_bits, num_nodes,
tree_offset);
tree_offset += d_tree.d_tree.size();
}
d_out_predictions[row_idx] += sum;
} else {
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
auto const tree_group = d_tree_group[tree_idx];
TreeView d_tree{tree_begin, tree_idx, d_nodes,
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
d_cat_node_segments, d_categories};
bst_uint out_prediction_idx = row_idx * num_group + tree_group;
d_out_predictions[out_prediction_idx] += GetLeafWeightByBitVector(
row_idx, d_tree, decision_bits, missing_bits, num_nodes, tree_offset);
tree_offset += d_tree.d_tree.size();
}
}
}

class ColumnSplitHelper {
public:
explicit ColumnSplitHelper(std::int32_t gpu_id) : gpu_id_{gpu_id} {}

void PredictBatch(DMatrix* dmat, HostDeviceVector<float>* out_preds,
gbm::GBTreeModel const& model, DeviceModel const& d_model) const {
CHECK(dmat->PageExists<SparsePage>()) << "Column split for external memory is not supported.";
PredictDMatrix(dmat, out_preds, d_model, model.learner_model_param->num_feature,
model.learner_model_param->num_output_group);
}

private:
using BitType = BitVector::value_type;

void PredictDMatrix(DMatrix* dmat, HostDeviceVector<float>* out_preds, DeviceModel const& model,
bst_feature_t num_features, std::uint32_t num_group) const {
dh::safe_cuda(cudaSetDevice(gpu_id_));
dh::caching_device_vector<BitType> decision_storage{};
dh::caching_device_vector<BitType> missing_storage{};

auto constexpr kBlockThreads = 128;
auto const max_shared_memory_bytes = dh::MaxSharedMemory(gpu_id_);
auto const shared_memory_bytes =
SharedMemoryBytes<kBlockThreads>(num_features, max_shared_memory_bytes);
auto const use_shared = shared_memory_bytes != 0;
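// SharedMemoryBytes() returns 0 when a full row cannot be staged in shared
// memory, in which case the loader reads features from global memory.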

auto const num_nodes = model.nodes.Size();
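// Each batch needs num_rows * num_nodes bits in both bit vectors.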
std::size_t batch_offset = 0;
for (auto const& batch : dmat->GetBatches<SparsePage>()) {
auto const num_rows = batch.Size();
ResizeBitVectors(&decision_storage, &missing_storage, num_rows * num_nodes);
BitVector decision_bits{dh::ToSpan(decision_storage)};
BitVector missing_bits{dh::ToSpan(missing_storage)};

batch.offset.SetDevice(gpu_id_);
batch.data.SetDevice(gpu_id_);
std::size_t entry_start = 0;
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features);

auto const grid = static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
MaskBitVectorKernel, data, model.nodes.ConstDeviceSpan(),
model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
model.split_types.ConstDeviceSpan(), model.categories_tree_segments.ConstDeviceSpan(),
model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(),
decision_bits, missing_bits, model.tree_beg_, model.tree_end_, num_features, num_rows,
entry_start, num_nodes, use_shared, nan(""));

AllReduceBitVectors(&decision_storage, &missing_storage);

dh::LaunchKernel {grid, kBlockThreads} (
PredictByBitVectorKernel, model.nodes.ConstDeviceSpan(),
out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(),
model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(),
model.categories_tree_segments.ConstDeviceSpan(),
model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(),
decision_bits, missing_bits, model.tree_beg_, model.tree_end_, num_rows, num_nodes,
num_group);

batch_offset += batch.Size() * num_group;
}
}

void AllReduceBitVectors(dh::caching_device_vector<BitType>* decision_storage,
dh::caching_device_vector<BitType>* missing_storage) const {
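// Decision bits are OR-ed across workers: only the worker owning a split
// feature ever sets them. Missing bits are AND-ed: a value is missing only
// if every worker, including the owner, saw it as missing.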
collective::AllReduce<collective::Operation::kBitwiseOR>(
    gpu_id_, decision_storage->data().get(), decision_storage->size());
collective::AllReduce<collective::Operation::kBitwiseAND>(
    gpu_id_, missing_storage->data().get(), missing_storage->size());
collective::Synchronize(gpu_id_);
}

static void ResizeBitVectors(dh::caching_device_vector<BitType>* decision_storage,
dh::caching_device_vector<BitType>* missing_storage,
std::size_t total_bits) {
auto const size = BitVector::ComputeStorageSize(total_bits);
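// Zero the (possibly reused) storage without shrinking it, so stale bits
// from a previous batch cannot leak into this one.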
if (decision_storage->size() < size) {
  decision_storage->resize(size);
}
thrust::fill(decision_storage->begin(), decision_storage->end(), 0);
if (missing_storage->size() < size) {
  missing_storage->resize(size);
}
thrust::fill(missing_storage->begin(), missing_storage->end(), 0);
}

std::int32_t const gpu_id_;
};
} // anonymous namespace

class GPUPredictor : public xgboost::Predictor {
@@ -697,6 +889,11 @@ class GPUPredictor : public xgboost::Predictor {
DeviceModel d_model;
d_model.Init(model, tree_begin, tree_end, ctx_->gpu_id);

if (dmat->Info().IsColumnSplit()) {
column_split_helper_.PredictBatch(dmat, out_preds, model, d_model);
return;
}

if (dmat->PageExists<SparsePage>()) {
size_t batch_offset = 0;
for (auto &batch : dmat->GetBatches<SparsePage>()) {
@@ -720,7 +917,8 @@ class GPUPredictor : public xgboost::Predictor {
}

public:
- explicit GPUPredictor(Context const* ctx) : Predictor::Predictor{ctx} {}
+ explicit GPUPredictor(Context const* ctx)
+     : Predictor::Predictor{ctx}, column_split_helper_{ctx->gpu_id} {}

~GPUPredictor() override {
if (ctx_->gpu_id >= 0 && ctx_->gpu_id < common::AllVisibleGPUs()) {
@@ -1019,6 +1217,8 @@ class GPUPredictor : public xgboost::Predictor {
}
return 0;
}

ColumnSplitHelper column_split_helper_;
};

XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
20 changes: 13 additions & 7 deletions src/predictor/predict_fn.h
@@ -7,20 +7,26 @@
#include "xgboost/tree_model.h"

namespace xgboost::predictor {
/** @brief Whether it should traverse to the left branch of a tree. */
template <bool has_categorical>
inline XGBOOST_DEVICE bool GetDecision(RegTree::Node const &node, bst_node_t nid, float fvalue,
RegTree::CategoricalSplitMatrix const &cats) {
if (has_categorical && common::IsCat(cats.split_type, nid)) {
auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
return common::Decision(node_categories, fvalue);
} else {
return fvalue < node.SplitCond();
}
}

template <bool has_missing, bool has_categorical>
inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
float fvalue, bool is_missing,
RegTree::CategoricalSplitMatrix const &cats) {
if (has_missing && is_missing) {
return node.DefaultChild();
} else {
- if (has_categorical && common::IsCat(cats.split_type, nid)) {
-   auto node_categories =
-       cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
-   return common::Decision(node_categories, fvalue) ? node.LeftChild() : node.RightChild();
- } else {
-   return node.LeftChild() + !(fvalue < node.SplitCond());
- }
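// RegTree allocates the right child at LeftChild() + 1, so adding the negated
// decision selects the left child (decision true) or the right (decision false).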
+ return node.LeftChild() + !GetDecision<has_categorical>(node, nid, fvalue, cats);
}
}
