Skip to content

Commit

Permalink
Share the batch parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Sep 20, 2024
1 parent dab1382 commit bc571c1
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 26 deletions.
24 changes: 24 additions & 0 deletions src/data/batch_utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/**
* Copyright 2024, XGBoost Contributors
*/
#pragma once

#include "xgboost/data.h" // for BatchParam

namespace xgboost::data::cuda_impl {
// Default number of batches used for prefetching. Two are used so that there is always
// one batch being worked on while the other batch is being transferred.
constexpr auto DftPrefetchBatches() {
  constexpr auto kDftPrefetch = 2;
  return kDftPrefetch;
}

// Build an otherwise empty (default-constructed) batch parameter. It exists to prevent
// regen and is only used to control external memory prefetching.
//
// Both approx and hist initialize the DMatrix before creating the actual implementation
// (InitDataOnce), which is why the `GPUHistMakerDevice` can use an empty parameter to
// avoid any regen.
//
// prefetch_copy is forwarded unchanged to `BatchParam::prefetch_copy`; the number of
// prefetched batches is always the default from `DftPrefetchBatches()`.
inline BatchParam StaticBatch(bool prefetch_copy) {
  BatchParam param;
  param.n_prefetch_batches = DftPrefetchBatches();
  param.prefetch_copy = prefetch_copy;
  return param;
}
} // namespace xgboost::data::cuda_impl
13 changes: 8 additions & 5 deletions src/predictor/gpu_predictor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/cuda_rt_utils.h" // for AllVisibleGPUs
#include "../common/device_helpers.cuh"
#include "../common/error_msg.h" // for InplacePredictProxy
#include "../common/error_msg.h" // for InplacePredictProxy
#include "../data/batch_utils.cuh" // for StaticBatch
#include "../data/device_adapter.cuh"
#include "../data/ellpack_page.cuh"
#include "../data/proxy_dmatrix.h"
Expand All @@ -31,6 +32,8 @@
namespace xgboost::predictor {
DMLC_REGISTRY_FILE_TAG(gpu_predictor);

using data::cuda_impl::StaticBatch;

struct TreeView {
RegTree::CategoricalSplitMatrix cats;
common::Span<RegTree::Node const> d_tree;
Expand Down Expand Up @@ -924,7 +927,7 @@ class GPUPredictor : public xgboost::Predictor {
}
} else {
bst_idx_t batch_offset = 0;
for (auto const& page : dmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
for (auto const& page : dmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
dmat->Info().feature_types.SetDevice(ctx_->Device());
auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
this->PredictInternal(page.Impl()->GetDeviceAccessor(ctx_, feature_types), d_model,
Expand Down Expand Up @@ -1067,7 +1070,7 @@ class GPUPredictor : public xgboost::Predictor {
X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
}
} else {
for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx_, {})) {
for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
EllpackDeviceAccessor acc{batch.Impl()->GetDeviceAccessor(ctx_)};
auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(),
std::numeric_limits<float>::quiet_NaN()};
Expand Down Expand Up @@ -1137,7 +1140,7 @@ class GPUPredictor : public xgboost::Predictor {
X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
}
} else {
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, {})) {
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
auto impl = batch.Impl();
auto acc = impl->GetDeviceAccessor(ctx_, p_fmat->Info().feature_types.ConstDeviceSpan());
auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size;
Expand Down Expand Up @@ -1223,7 +1226,7 @@ class GPUPredictor : public xgboost::Predictor {
}
} else {
bst_idx_t batch_offset = 0;
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, StaticBatch(true))) {
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_)};
auto grid = static_cast<std::uint32_t>(common::DivRoundUp(batch.Size(), kBlockThreads));
launch(PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, grid, data, batch_offset);
Expand Down
26 changes: 10 additions & 16 deletions src/tree/updater_gpu_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
#include <limits> // for numeric_limits
#include <ostream> // for ostream

#include "gpu_hist/quantiser.cuh" // for GradientQuantiser
#include "param.h" // for TrainParam
#include "xgboost/base.h" // for bst_bin_t
#include "xgboost/task.h" // for ObjInfo
#include "../data/batch_utils.cuh" // for DftPrefetchBatches, StaticBatch
#include "gpu_hist/quantiser.cuh" // for GradientQuantiser
#include "param.h" // for TrainParam
#include "xgboost/base.h" // for bst_bin_t
#include "xgboost/task.h" // for ObjInfo

namespace xgboost::tree {
struct GPUTrainingParam {
Expand Down Expand Up @@ -119,26 +120,19 @@ struct DeviceSplitCandidate {
};

namespace cuda_impl {
constexpr auto DftPrefetchBatches() { return 2; }

inline BatchParam HistBatch(TrainParam const& param) {
auto p = BatchParam{param.max_bin, TrainParam::DftSparseThreshold()};
p.prefetch_copy = true;
p.n_prefetch_batches = DftPrefetchBatches();
p.n_prefetch_batches = data::cuda_impl::DftPrefetchBatches();
return p;
}

inline BatchParam ApproxBatch(TrainParam const& p, common::Span<float const> hess,
ObjInfo const& task) {
return BatchParam{p.max_bin, hess, !task.const_hess};
}

// Empty parameter to prevent regen, only used to control external memory prefetching.
inline BatchParam StaticBatch(bool prefetch_copy) {
BatchParam p;
p.prefetch_copy = prefetch_copy;
p.n_prefetch_batches = DftPrefetchBatches();
return p;
auto batch = BatchParam{p.max_bin, hess, !task.const_hess};
batch.prefetch_copy = true;
batch.n_prefetch_batches = data::cuda_impl::DftPrefetchBatches();
return batch;
}
} // namespace cuda_impl

Expand Down
7 changes: 2 additions & 5 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "../common/hist_util.h" // for HistogramCuts
#include "../common/random.h" // for ColumnSampler, GlobalRandom
#include "../common/timer.h"
#include "../data/batch_utils.cuh" // for StaticBatch
#include "../data/ellpack_page.cuh"
#include "../data/ellpack_page.h"
#include "constraints.cuh"
Expand Down Expand Up @@ -50,11 +51,7 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);

using cuda_impl::ApproxBatch;
using cuda_impl::HistBatch;

// Both approx and hist initialize the DMatrix before creating the actual
// implementation (InitDataOnce). Therefore, the `GPUHistMakerDevice` can use an empty
// parameter to avoid any regen.
using cuda_impl::StaticBatch;
using data::cuda_impl::StaticBatch;

// Extra data for each node that is passed to the update position function
struct NodeSplitData {
Expand Down

0 comments on commit bc571c1

Please sign in to comment.