diff --git a/include/xgboost/context.h b/include/xgboost/context.h
index b11ca70ec06c..de7648079a30 100644
--- a/include/xgboost/context.h
+++ b/include/xgboost/context.h
@@ -9,9 +9,10 @@
 #include <xgboost/logging.h>    // for CHECK_GE
 #include <xgboost/parameter.h>  // for XGBoostParameter

-#include <cstdint>  // for int16_t, int32_t, int64_t
-#include <memory>   // for shared_ptr
-#include <string>   // for string, to_string
+#include <cstdint>      // for int16_t, int32_t, int64_t
+#include <memory>       // for shared_ptr
+#include <string>       // for string, to_string
+#include <type_traits>  // for invoke_result_t, is_same_v

 namespace xgboost {

@@ -152,6 +153,25 @@ struct Context : public XGBoostParameter<Context> {
     ctx.gpu_id = kCpuId;
     return ctx;
   }
+  /**
+   * @brief Call function based on the current device.
+   */
+  template <typename CPUFn, typename CUDAFn>
+  decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn) const {
+    static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<CUDAFn>>);
+    switch (this->Device().device) {
+      case DeviceOrd::kCPU:
+        return cpu_fn();
+      case DeviceOrd::kCUDA:
+        return cuda_fn();
+      default:
+        // Do not use the device name as this is likely an internal error, the name
+        // wouldn't be valid.
+        LOG(FATAL) << "Unknown device type:" << static_cast<std::int16_t>(this->Device().device);
+        break;
+    }
+    return std::invoke_result_t<CPUFn>();
+  }

   // declare parameters
   DMLC_DECLARE_PARAMETER(Context) {
diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h
index 615bc0f398bc..f0d2e8e37b4f 100644
--- a/include/xgboost/predictor.h
+++ b/include/xgboost/predictor.h
@@ -6,24 +6,22 @@
  */
 #pragma once
 #include
-#include <xgboost/cache.h>  // DMatrixCache
+#include <xgboost/cache.h>    // for DMatrixCache
+#include <xgboost/context.h>  // for Context
 #include
 #include
 #include
-#include <functional>  // std::function
-#include <memory>
+#include <functional>  // for function
+#include <memory>      // for shared_ptr
 #include
-#include <thread>   // for get_id
 #include <utility>  // for make_pair
 #include

 // Forward declarations
-namespace xgboost {
-namespace gbm {
+namespace xgboost::gbm {
 struct GBTreeModel;
-}  // namespace gbm
-}  // namespace xgboost
+}  // namespace xgboost::gbm

 namespace xgboost {
 /**
diff --git a/src/common/error_msg.h b/src/common/error_msg.h
index e690a12f33a2..e9b9fc56ba29 100644
--- a/src/common/error_msg.h
+++ b/src/common/error_msg.h
@@ -47,5 +47,9 @@ inline void MaxFeatureSize(std::uint64_t n_features) {
       << "Unfortunately, XGBoost does not support data matrices with "
       << std::numeric_limits<bst_feature_t>::max() << " features or greater";
 }
+
+constexpr StringView InplacePredictProxy() {
+  return "Inplace predict accepts only DMatrixProxy as input.";
+}
 }  // namespace xgboost::error
 #endif  // XGBOOST_COMMON_ERROR_MSG_H_
diff --git a/src/data/gradient_index_format.cc b/src/data/gradient_index_format.cc
index ac52c0697304..241abfb1f644 100644
--- a/src/data/gradient_index_format.cc
+++ b/src/data/gradient_index_format.cc
@@ -68,6 +68,7 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
   }

   std::size_t Write(GHistIndexMatrix const& page, common::AlignedFileWriteStream* fo) override {
+    CHECK_NE(page.index.Size(), 0) << "Empty page is not supported.";
     std::size_t bytes = 0;
     bytes += WriteHistogramCuts(page.cut, fo);
     // indptr
diff --git a/src/data/gradient_index_page_source.cc b/src/data/gradient_index_page_source.cc
index 6fa2f07e0ddd..1b2ed3fddd4d 100644
--- a/src/data/gradient_index_page_source.cc
+++ b/src/data/gradient_index_page_source.cc
@@ -1,10 +1,9 @@
-/*!
- * Copyright 2021-2022 by XGBoost Contributors +/** + * Copyright 2021-2023, XGBoost Contributors */ #include "gradient_index_page_source.h" -namespace xgboost { -namespace data { +namespace xgboost::data { void GradientIndexPageSource::Fetch() { if (!this->ReadCache()) { if (count_ != 0 && !sync_) { @@ -21,5 +20,4 @@ void GradientIndexPageSource::Fetch() { this->WriteCache(); } } -} // namespace data -} // namespace xgboost +} // namespace xgboost::data diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 9d595c378f3e..b5c1573b1b27 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -18,7 +18,7 @@ #include #include "../common/common.h" -#include "../common/error_msg.h" // for UnknownDevice +#include "../common/error_msg.h" // for UnknownDevice, InplacePredictProxy #include "../common/random.h" #include "../common/threading_utils.h" #include "../common/timer.h" @@ -542,6 +542,18 @@ void GBTree::PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds, } } +namespace { +inline void MismatchedDevices(Context const* booster, Context const* data) { + LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost " + << "is running on: " << booster->DeviceName() + << ", while the input data is on: " << data->DeviceName() << ".\n" + << R"(Potential solutions: +- Use a data structure that matches the device ordinal in the booster. +- Set the device for booster before call to inplace_predict. +)"; +} +}; // namespace + void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training, bst_layer_t layer_begin, bst_layer_t layer_end) { // dispatch to const function. @@ -555,24 +567,26 @@ void GBTree::InplacePredict(std::shared_ptr p_m, float missing, auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end); CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees."; if (p_m->Ctx()->Device() != this->ctx_->Device()) { - LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. 
XGBoost " - << "is running on: " << this->ctx_->DeviceName() - << ", while the input data is on: " << p_m->Ctx()->DeviceName() << "."; + MismatchedDevices(this->ctx_, p_m->Ctx()); CHECK_EQ(out_preds->version, 0); auto proxy = std::dynamic_pointer_cast(p_m); - auto any_adapter = proxy->Adapter(); + CHECK(proxy) << error::InplacePredictProxy(); auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing); this->PredictBatchImpl(p_fmat.get(), out_preds, false, layer_begin, layer_end); return; } - if (this->ctx_->IsCPU()) { - this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end); - } else if (p_m->Ctx()->IsCUDA()) { - CHECK(this->gpu_predictor_); - this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end); - } else { - LOG(FATAL) << error::UnknownDevice(); + bool known_type = this->ctx_->DispatchDevice( + [&, begin = tree_begin, end = tree_end] { + return this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end); + }, + [&, begin = tree_begin, end = tree_end] { + return this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end); + }); + if (!known_type) { + auto proxy = std::dynamic_pointer_cast(p_m); + CHECK(proxy) << error::InplacePredictProxy(); + LOG(FATAL) << "Unknown data type for inplace prediction:" << proxy->Adapter().type().name(); } } @@ -808,11 +822,9 @@ class Dart : public GBTree { auto n_groups = model_.learner_model_param->num_output_group; if (ctx_->Device() != p_fmat->Ctx()->Device()) { - LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost " - << "is running on: " << this->ctx_->DeviceName() - << ", while the input data is on: " << p_fmat->Ctx()->DeviceName() << "."; + MismatchedDevices(ctx_, p_fmat->Ctx()); auto proxy = std::dynamic_pointer_cast(p_fmat); - auto any_adapter = proxy->Adapter(); + CHECK(proxy) << error::InplacePredictProxy(); auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing); this->PredictBatchImpl(p_fmat.get(), p_out_preds, false, layer_begin, layer_end); return; @@ -825,20 +837,15 @@ class Dart : public GBTree { } predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0); - auto get_predictor = [&]() -> Predictor const* { - if (ctx_->IsCPU()) { - return cpu_predictor_.get(); - } else if (ctx_->IsCUDA()) { - CHECK(this->gpu_predictor_); - return gpu_predictor_.get(); - } else { - LOG(FATAL) << error::UnknownDevice(); - return nullptr; - } - }; auto predict_impl = [&](size_t i) { predts.predictions.Fill(0); - bool success{get_predictor()->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1)}; + bool success = this->ctx_->DispatchDevice( + [&] { + return cpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1); + }, + [&] { + return gpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1); + }); CHECK(success) << msg; }; @@ -846,7 +853,15 @@ class Dart : public GBTree { for (bst_tree_t i = tree_begin; i < tree_end; ++i) { predict_impl(i); if (i == tree_begin) { - get_predictor()->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_); + this->ctx_->DispatchDevice( + [&] { + this->cpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, + model_); + }, + [&] { + this->gpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, + model_); + }); } // Multiple the tree weight auto w = this->weight_drop_.at(i); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc 
index b9cb02d56139..c092c0b04588 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -16,6 +16,7 @@ #include "../common/bitfield.h" // for RBitField8 #include "../common/categorical.h" // for IsCat, Decision #include "../common/common.h" // for DivRoundUp +#include "../common/error_msg.h" // for InplacePredictProxy #include "../common/math.h" // for CheckNAN #include "../common/threading_utils.h" // for ParallelFor #include "../data/adapter.h" // for ArrayAdapter, CSRAdapter, CSRArrayAdapter @@ -741,7 +742,7 @@ class CPUPredictor : public Predictor { PredictionCacheEntry *out_preds, uint32_t tree_begin, unsigned tree_end) const override { auto proxy = dynamic_cast(p_m.get()); - CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input."; + CHECK(proxy)<< error::InplacePredictProxy(); CHECK(!p_m->Info().IsColumnSplit()) << "Inplace predict support for column-wise data split is not yet implemented."; auto x = proxy->Adapter(); diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 4ca0e33fff55..578fda180f37 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -15,8 +15,9 @@ #include "../common/bitfield.h" #include "../common/categorical.h" #include "../common/common.h" -#include "../common/cuda_context.cuh" +#include "../common/cuda_context.cuh" // for CUDAContext #include "../common/device_helpers.cuh" +#include "../common/error_msg.h" // for InplacePredictProxy #include "../data/device_adapter.cuh" #include "../data/ellpack_page.cuh" #include "../data/proxy_dmatrix.h" @@ -989,7 +990,7 @@ class GPUPredictor : public xgboost::Predictor { PredictionCacheEntry* out_preds, uint32_t tree_begin, unsigned tree_end) const override { auto proxy = dynamic_cast(p_m.get()); - CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input."; + CHECK(proxy) << error::InplacePredictProxy(); auto x = proxy->Adapter(); if (x.type() == typeid(std::shared_ptr)) { this->DispatchedInplacePredict dmat = CreateSparsePageDMatrix(10000); + auto p_fmat = RandomDataGenerator{4096, 256, 0.5} + .Device(ctx.gpu_id) + .Batches(8) + .GenerateSparsePageDMatrix("cache", true); + std::vector base_rowids; - std::vector hessian(dmat->Info().num_row_, 1); - for (auto const &page : dmat->GetBatches(&ctx, {64, hessian, true})) { + std::vector hessian(p_fmat->Info().num_row_, 1); + for (auto const &page : p_fmat->GetBatches(&ctx, {64, hessian, true})) { base_rowids.push_back(page.base_rowid); } - size_t i = 0; - for (auto const &page : dmat->GetBatches()) { + + std::size_t i = 0; + for (auto const &page : p_fmat->GetBatches()) { ASSERT_EQ(base_rowids[i], page.base_rowid); ++i; } base_rowids.clear(); - for (auto const &page : dmat->GetBatches(&ctx, {64, hessian, false})) { + for (auto const &page : p_fmat->GetBatches(&ctx, {64, hessian, false})) { base_rowids.push_back(page.base_rowid); } i = 0; - for (auto const &page : dmat->GetBatches()) { + for (auto const &page : p_fmat->GetBatches()) { ASSERT_EQ(base_rowids[i], page.base_rowid); ++i; } diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index d1e9e624252c..839ea762e970 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -76,9 +76,11 @@ TEST(SparsePageDMatrix, LoadFile) { // allow caller to retain pages so they can process multiple pages at the same time. 
template void TestRetainPage() { - auto m = CreateSparsePageDMatrix(10000); + std::size_t n_batches = 4; + auto p_fmat = RandomDataGenerator{1024, 128, 0.5f}.Batches(n_batches).GenerateSparsePageDMatrix( + "cache", true); Context ctx; - auto batches = m->GetBatches(&ctx); + auto batches = p_fmat->GetBatches(&ctx); auto begin = batches.begin(); auto end = batches.end(); @@ -94,7 +96,7 @@ void TestRetainPage() { } ASSERT_EQ(pages.back().Size(), (*it).Size()); } - ASSERT_GE(iterators.size(), 2); + ASSERT_GE(iterators.size(), n_batches); for (size_t i = 0; i < iterators.size(); ++i) { ASSERT_EQ((*iterators[i]).Size(), pages.at(i).Size()); @@ -102,7 +104,7 @@ void TestRetainPage() { } // make sure it's const and the caller can not modify the content of page. - for (auto &page : m->GetBatches({&ctx})) { + for (auto &page : p_fmat->GetBatches({&ctx})) { static_assert(std::is_const>::value); } } diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 2bc0b2c6bc10..f57b1f47ccce 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -514,4 +514,86 @@ TEST(GBTree, PredictRange) { dmlc::Error); } } + +TEST(GBTree, InplacePredictionError) { + std::size_t n_samples{2048}, n_features{32}; + + auto test_ext_err = [&](std::string booster, Context const* ctx) { + std::shared_ptr p_fmat = + RandomDataGenerator{n_samples, n_features, 0.5f}.Batches(2).GenerateSparsePageDMatrix( + "cache", true); + std::unique_ptr learner{Learner::Create({p_fmat})}; + learner->SetParam("booster", booster); + ConfigLearnerByCtx(ctx, learner.get()); + learner->Configure(); + for (std::int32_t i = 0; i < 3; ++i) { + learner->UpdateOneIter(i, p_fmat); + } + HostDeviceVector* out_predt; + ASSERT_THROW( + { + learner->InplacePredict(p_fmat, PredictionType::kValue, + std::numeric_limits::quiet_NaN(), &out_predt, 0, 0); + }, + dmlc::Error); + }; + + { + Context ctx; + test_ext_err("gbtree", &ctx); + test_ext_err("dart", &ctx); + } + +#if defined(XGBOOST_USE_CUDA) + { + auto ctx = MakeCUDACtx(0); + test_ext_err("gbtree", &ctx); + test_ext_err("dart", &ctx); + } +#endif // defined(XGBOOST_USE_CUDA) + + auto test_qdm_err = [&](std::string booster, Context const* ctx) { + std::shared_ptr p_fmat; + bst_bin_t max_bins = 16; + auto rng = RandomDataGenerator{n_samples, n_features, 0.5f}.Device(ctx->gpu_id).Bins(max_bins); + if (ctx->IsCPU()) { + p_fmat = rng.GenerateQuantileDMatrix(true); + } else { +#if defined(XGBOOST_USE_CUDA) + p_fmat = rng.GenerateDeviceDMatrix(true); +#else + CHECK(p_fmat); +#endif // defined(XGBOOST_USE_CUDA) + }; + std::unique_ptr learner{Learner::Create({p_fmat})}; + learner->SetParam("booster", booster); + learner->SetParam("max_bin", std::to_string(max_bins)); + ConfigLearnerByCtx(ctx, learner.get()); + learner->Configure(); + for (std::int32_t i = 0; i < 3; ++i) { + learner->UpdateOneIter(i, p_fmat); + } + HostDeviceVector* out_predt; + ASSERT_THROW( + { + learner->InplacePredict(p_fmat, PredictionType::kValue, + std::numeric_limits::quiet_NaN(), &out_predt, 0, 0); + }, + dmlc::Error); + }; + + { + Context ctx; + test_qdm_err("gbtree", &ctx); + test_qdm_err("dart", &ctx); + } + +#if defined(XGBOOST_USE_CUDA) + { + auto ctx = MakeCUDACtx(0); + test_qdm_err("gbtree", &ctx); + test_qdm_err("dart", &ctx); + } +#endif // defined(XGBOOST_USE_CUDA) +} } // namespace xgboost diff --git a/tests/cpp/gbm/test_gbtree.cu b/tests/cpp/gbm/test_gbtree.cu index 2393bfabd0ff..7321be75e65c 100644 --- a/tests/cpp/gbm/test_gbtree.cu +++ b/tests/cpp/gbm/test_gbtree.cu @@ -61,7 +61,6 @@ 
void TestInplaceFallback(Context const* ctx) { learner->InplacePredict(p_m, PredictionType::kValue, std::numeric_limits::quiet_NaN(), &out_predt, 0, 0); auto output = testing::internal::GetCapturedStderr(); - std::cout << "output:" << output << std::endl; ASSERT_NE(output.find("Falling back"), std::string::npos); // test when the contexts match diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 49ff5e4127aa..4f44b7b1e1c7 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -210,6 +210,16 @@ SimpleLCG::StateType SimpleLCG::Max() const { return max(); } // Make sure it's compile time constant. static_assert(SimpleLCG::max() - SimpleLCG::min()); +void RandomDataGenerator::GenerateLabels(std::shared_ptr p_fmat) const { + RandomDataGenerator{p_fmat->Info().num_row_, this->n_targets_, 0.0f}.GenerateDense( + p_fmat->Info().labels.Data()); + CHECK_EQ(p_fmat->Info().labels.Size(), this->rows_ * this->n_targets_); + p_fmat->Info().labels.Reshape(this->rows_, this->n_targets_); + if (device_ != Context::kCpuId) { + p_fmat->Info().labels.SetDevice(device_); + } +} + void RandomDataGenerator::GenerateDense(HostDeviceVector *out) const { xgboost::SimpleRealUniformDistribution dist(lower_, upper_); CHECK(out); @@ -363,8 +373,9 @@ void RandomDataGenerator::GenerateCSR( CHECK_EQ(columns->Size(), value->Size()); } -std::shared_ptr RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label, - size_t classes) const { +[[nodiscard]] std::shared_ptr RandomDataGenerator::GenerateDMatrix(bool with_label, + bool float_label, + size_t classes) const { HostDeviceVector data; HostDeviceVector rptrs; HostDeviceVector columns; @@ -406,10 +417,58 @@ std::shared_ptr RandomDataGenerator::GenerateDMatrix(bool with_label, b return out; } -std::shared_ptr RandomDataGenerator::GenerateQuantileDMatrix() { +[[nodiscard]] std::shared_ptr RandomDataGenerator::GenerateSparsePageDMatrix( + std::string prefix, bool with_label) const { + CHECK_GE(this->rows_, this->n_batches_); + CHECK_GE(this->n_batches_, 1) + << "Must set the n_batches before generating an external memory DMatrix."; + std::unique_ptr iter; + if (device_ == Context::kCpuId) { + iter = std::make_unique(this->sparsity_, rows_, cols_, n_batches_); + } else { +#if defined(XGBOOST_USE_CUDA) + iter = std::make_unique(this->sparsity_, rows_, cols_, n_batches_); +#else + CHECK(iter); +#endif // defined(XGBOOST_USE_CUDA) + } + + std::unique_ptr dmat{ + DMatrix::Create(static_cast(iter.get()), iter->Proxy(), Reset, Next, + std::numeric_limits::quiet_NaN(), Context{}.Threads(), prefix)}; + + auto row_page_path = + data::MakeId(prefix, dynamic_cast(dmat.get())) + ".row.page"; + EXPECT_TRUE(FileExists(row_page_path)) << row_page_path; + + // Loop over the batches and count the number of pages + std::size_t batch_count = 0; + bst_row_t row_count = 0; + for (const auto& batch : dmat->GetBatches()) { + batch_count++; + row_count += batch.Size(); + CHECK_NE(batch.data.Size(), 0); + } + + EXPECT_EQ(batch_count, n_batches_); + EXPECT_EQ(row_count, dmat->Info().num_row_); + + if (with_label) { + RandomDataGenerator{dmat->Info().num_row_, this->n_targets_, 0.0f}.GenerateDense( + dmat->Info().labels.Data()); + CHECK_EQ(dmat->Info().labels.Size(), this->rows_ * this->n_targets_); + dmat->Info().labels.Reshape(this->rows_, this->n_targets_); + } + return dmat; +} + +std::shared_ptr RandomDataGenerator::GenerateQuantileDMatrix(bool with_label) { NumpyArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1}; auto m = std::make_shared( &iter, 
iter.Proxy(), nullptr, Reset, Next, std::numeric_limits::quiet_NaN(), 0, bins_); + if (with_label) { + this->GenerateLabels(m); + } return m; } diff --git a/tests/cpp/helpers.cu b/tests/cpp/helpers.cu index f72281cb4dbb..10b800fc1c82 100644 --- a/tests/cpp/helpers.cu +++ b/tests/cpp/helpers.cu @@ -24,10 +24,13 @@ int CudaArrayIterForTest::Next() { return 1; } -std::shared_ptr RandomDataGenerator::GenerateDeviceDMatrix() { +std::shared_ptr RandomDataGenerator::GenerateDeviceDMatrix(bool with_label) { CudaArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1}; auto m = std::make_shared( &iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits::quiet_NaN(), 0, bins_); + if (with_label) { + this->GenerateLabels(m); + } return m; } } // namespace xgboost diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 035baf22a013..449d97a407df 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -238,15 +238,18 @@ class RandomDataGenerator { bst_target_t n_targets_{1}; std::int32_t device_{Context::kCpuId}; + std::size_t n_batches_{0}; std::uint64_t seed_{0}; SimpleLCG lcg_; - std::size_t bins_{0}; + bst_bin_t bins_{0}; std::vector ft_; bst_cat_t max_cat_; Json ArrayInterfaceImpl(HostDeviceVector* storage, size_t rows, size_t cols) const; + void GenerateLabels(std::shared_ptr p_fmat) const; + public: RandomDataGenerator(bst_row_t rows, size_t cols, float sparsity) : rows_{rows}, cols_{cols}, sparsity_{sparsity}, lcg_{seed_} {} @@ -263,12 +266,16 @@ class RandomDataGenerator { device_ = d; return *this; } + RandomDataGenerator& Batches(std::size_t n_batches) { + n_batches_ = n_batches; + return *this; + } RandomDataGenerator& Seed(uint64_t s) { seed_ = s; lcg_.Seed(seed_); return *this; } - RandomDataGenerator& Bins(size_t b) { + RandomDataGenerator& Bins(bst_bin_t b) { bins_ = b; return *this; } @@ -309,12 +316,17 @@ class RandomDataGenerator { void GenerateCSR(HostDeviceVector* value, HostDeviceVector* row_ptr, HostDeviceVector* columns) const; - std::shared_ptr GenerateDMatrix(bool with_label = false, bool float_label = true, - size_t classes = 1) const; + [[nodiscard]] std::shared_ptr GenerateDMatrix(bool with_label = false, + bool float_label = true, + size_t classes = 1) const; + + [[nodiscard]] std::shared_ptr GenerateSparsePageDMatrix(std::string prefix, + bool with_label) const; + #if defined(XGBOOST_USE_CUDA) - std::shared_ptr GenerateDeviceDMatrix(); + std::shared_ptr GenerateDeviceDMatrix(bool with_label); #endif - std::shared_ptr GenerateQuantileDMatrix(); + std::shared_ptr GenerateQuantileDMatrix(bool with_label); }; // Generate an empty DMatrix, mostly for its meta info. 
@@ -443,11 +455,11 @@ class ArrayIterForTest { size_t static constexpr Cols() { return 13; } public: - std::string AsArray() const { return interface_; } + [[nodiscard]] std::string AsArray() const { return interface_; } virtual int Next() = 0; virtual void Reset() { iter_ = 0; } - size_t Iter() const { return iter_; } + [[nodiscard]] std::size_t Iter() const { return iter_; } auto Proxy() -> decltype(proxy_) { return proxy_; } explicit ArrayIterForTest(float sparsity, size_t rows, size_t cols, size_t batches); diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 087543cfe160..841a576d5a8a 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -216,7 +216,7 @@ void TestUpdatePredictionCache(bool use_subsampling) { TEST(CPUPredictor, GHistIndex) { size_t constexpr kRows{128}, kCols{16}, kBins{64}; - auto p_hist = RandomDataGenerator{kRows, kCols, 0.0}.Bins(kBins).GenerateQuantileDMatrix(); + auto p_hist = RandomDataGenerator{kRows, kCols, 0.0}.Bins(kBins).GenerateQuantileDMatrix(false); HostDeviceVector storage(kRows * kCols); auto columnar = RandomDataGenerator{kRows, kCols, 0.0}.GenerateArrayInterface(&storage); auto adapter = data::ArrayAdapter(columnar.c_str()); diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 30fbaf997ffd..15fbd462e94a 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -123,7 +123,8 @@ TEST(GPUPredictor, EllpackBasic) { auto ctx = MakeCUDACtx(0); for (size_t bins = 2; bins < 258; bins += 16) { size_t rows = bins * 16; - auto p_m = RandomDataGenerator{rows, kCols, 0.0}.Bins(bins).Device(0).GenerateDeviceDMatrix(); + auto p_m = + RandomDataGenerator{rows, kCols, 0.0}.Bins(bins).Device(0).GenerateDeviceDMatrix(false); ASSERT_FALSE(p_m->PageExists()); TestPredictionFromGradientIndex(&ctx, rows, kCols, p_m); TestPredictionFromGradientIndex(&ctx, bins, kCols, p_m); @@ -133,7 +134,7 @@ TEST(GPUPredictor, EllpackBasic) { TEST(GPUPredictor, EllpackTraining) { size_t constexpr kRows { 128 }, kCols { 16 }, kBins { 64 }; auto p_ellpack = - RandomDataGenerator{kRows, kCols, 0.0}.Bins(kBins).Device(0).GenerateDeviceDMatrix(); + RandomDataGenerator{kRows, kCols, 0.0}.Bins(kBins).Device(0).GenerateDeviceDMatrix(false); HostDeviceVector storage(kRows * kCols); auto columnar = RandomDataGenerator{kRows, kCols, 0.0} .Device(0) @@ -219,7 +220,7 @@ TEST(GPUPredictor, ShapStump) { gbm::GBTreeModel model(&mparam, &ctx); std::vector> trees; - trees.push_back(std::unique_ptr(new RegTree)); + trees.push_back(std::make_unique()); model.CommitModelGroup(std::move(trees), 0); auto gpu_lparam = MakeCUDACtx(0); @@ -246,7 +247,7 @@ TEST(GPUPredictor, Shap) { gbm::GBTreeModel model(&mparam, &ctx); std::vector> trees; - trees.push_back(std::unique_ptr(new RegTree)); + trees.push_back(std::make_unique()); trees[0]->ExpandNode(0, 0, 0.5, true, 1.0, -1.0, 1.0, 0.0, 5.0, 2.0, 3.0); model.CommitModelGroup(std::move(trees), 0);
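
Note on the new Context::DispatchDevice helper: a minimal usage sketch, assuming only what the patch above declares (Context, DeviceOrd, DispatchDevice); the wrapper function and its return values are hypothetical and stand in for real CPU/CUDA code paths.

    #include <xgboost/context.h>

    // Both callables must return the same type; the static_assert inside
    // DispatchDevice enforces this, and an unknown device type hits LOG(FATAL).
    double Reduce(xgboost::Context const* ctx) {
      return ctx->DispatchDevice([&] { return 1.0; },   // CPU implementation
                                 [&] { return 2.0; });  // CUDA implementation
    }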
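
Similarly, a sketch of the external-memory test helper added to RandomDataGenerator; the test name, dimensions, and the "cache" prefix are illustrative, and the include path may differ depending on where the test lives under tests/cpp.

    #include <gtest/gtest.h>

    #include "helpers.h"  // for RandomDataGenerator

    namespace xgboost {
    TEST(RandomDataGenerator, ExternalMemorySketch) {
      // A 1024x16 sparse DMatrix split into 4 row batches, written to disk
      // under the "cache" prefix, with labels generated as well.
      auto p_fmat = RandomDataGenerator{1024, 16, 0.5f}.Batches(4).GenerateSparsePageDMatrix(
          "cache", /*with_label=*/true);
      std::size_t n_batches{0};
      for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
        n_batches += page.Size() != 0;
      }
      ASSERT_EQ(n_batches, 4);
      ASSERT_EQ(p_fmat->Info().labels.Shape(0), p_fmat->Info().num_row_);
    }
    }  // namespace xgboost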