Improve test coverage with predictor configuration. #9354

Merged: 3 commits, Jul 5, 2023
26 changes: 23 additions & 3 deletions include/xgboost/context.h
@@ -9,9 +9,10 @@
 #include <xgboost/logging.h>    // for CHECK_GE
 #include <xgboost/parameter.h>  // for XGBoostParameter
 
-#include <cstdint>  // for int16_t, int32_t, int64_t
-#include <memory>   // for shared_ptr
-#include <string>   // for string, to_string
+#include <cstdint>      // for int16_t, int32_t, int64_t
+#include <memory>       // for shared_ptr
+#include <string>       // for string, to_string
+#include <type_traits>  // for invoke_result_t, is_same_v
 
 namespace xgboost {

@@ -152,6 +153,25 @@ struct Context : public XGBoostParameter<Context> {
     ctx.gpu_id = kCpuId;
     return ctx;
   }
+  /**
+   * @brief Call function based on the current device.
+   */
+  template <typename CPUFn, typename CUDAFn>
+  decltype(auto) DispatchDevice(CPUFn&& cpu_fn, CUDAFn&& cuda_fn) const {
+    static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<CUDAFn>>);
+    switch (this->Device().device) {
+      case DeviceOrd::kCPU:
+        return cpu_fn();
+      case DeviceOrd::kCUDA:
+        return cuda_fn();
+      default:
+        // Do not use the device name as this is likely an internal error, the name
+        // wouldn't be valid.
+        LOG(FATAL) << "Unknown device type:" << static_cast<std::int16_t>(this->Device().device);
+        break;
+    }
+    return std::invoke_result_t<CPUFn>();
+  }
 
   // declare parameters
   DMLC_DECLARE_PARAMETER(Context) {
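For illustration, here is a self-contained sketch (not part of the diff) that mirrors the contract of the new Context::DispatchDevice: the two callables must agree on their return type, which the static_assert enforces at compile time, and the unreachable trailing return exists only to keep -Wreturn-type quiet after the LOG(FATAL) path. The Device enum and the lambdas below are stand-ins invented for this sketch.

#include <cstdint>      // for int16_t
#include <iostream>     // for cout
#include <string>       // for string
#include <type_traits>  // for invoke_result_t, is_same_v

// Stand-in for xgboost's device enum; only for this sketch.
enum class Device : std::int16_t { kCPU = 0, kCUDA = 1 };

template <typename CPUFn, typename CUDAFn>
decltype(auto) DispatchDevice(Device d, CPUFn&& cpu_fn, CUDAFn&& cuda_fn) {
  // Both branches must return the same type, checked at compile time.
  static_assert(std::is_same_v<std::invoke_result_t<CPUFn>, std::invoke_result_t<CUDAFn>>);
  switch (d) {
    case Device::kCPU:
      return cpu_fn();
    case Device::kCUDA:
      return cuda_fn();
  }
  // Unreachable for valid input; mirrors the trailing return in the real helper.
  return std::invoke_result_t<CPUFn>();
}

int main() {
  auto path = DispatchDevice(
      Device::kCPU, [] { return std::string{"cpu path"}; },
      [] { return std::string{"cuda path"}; });
  std::cout << path << "\n";  // prints "cpu path"
  return 0;
}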
14 changes: 6 additions & 8 deletions include/xgboost/predictor.h
@@ -6,24 +6,22 @@
  */
 #pragma once
 #include <xgboost/base.h>
-#include <xgboost/cache.h>  // DMatrixCache
-#include <xgboost/context.h>
+#include <xgboost/cache.h>    // for DMatrixCache
+#include <xgboost/context.h>  // for Context
 #include <xgboost/data.h>
 #include <xgboost/host_device_vector.h>
 
-#include <functional>  // std::function
-#include <memory>
+#include <functional>  // for function
+#include <memory>      // for shared_ptr
 #include <string>
 #include <thread>   // for get_id
 #include <utility>  // for make_pair
 #include <vector>
 
 // Forward declarations
-namespace xgboost {
-namespace gbm {
+namespace xgboost::gbm {
 struct GBTreeModel;
-}  // namespace gbm
-}  // namespace xgboost
+}  // namespace xgboost::gbm
 
 namespace xgboost {
 /**
4 changes: 4 additions & 0 deletions src/common/error_msg.h
@@ -47,5 +47,9 @@ inline void MaxFeatureSize(std::uint64_t n_features) {
       << "Unfortunately, XGBoost does not support data matrices with "
       << std::numeric_limits<bst_feature_t>::max() << " features or greater";
 }
+
+constexpr StringView InplacePredictProxy() {
+  return "Inplace predict accepts only DMatrixProxy as input.";
+}
 }  // namespace xgboost::error
 #endif  // XGBOOST_COMMON_ERROR_MSG_H_
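The pattern is easy to replicate outside the code base; here is a minimal standalone sketch, with std::string_view standing in for xgboost's StringView (an assumption made only for self-containment): a constexpr function returning a lightweight view keeps the message defined in one place, shared by every call site, with no runtime construction cost.

#include <iostream>     // for cout
#include <string_view>  // for string_view

// Single source of truth for the error message; no static-initialization cost.
constexpr std::string_view InplacePredictProxy() {
  return "Inplace predict accepts only DMatrixProxy as input.";
}

int main() { std::cout << InplacePredictProxy() << '\n'; }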
1 change: 1 addition & 0 deletions src/data/gradient_index_format.cc
@@ -68,6 +68,7 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
   }
 
   std::size_t Write(GHistIndexMatrix const& page, common::AlignedFileWriteStream* fo) override {
+    CHECK_NE(page.index.Size(), 0) << "Empty page is not supported.";
     std::size_t bytes = 0;
     bytes += WriteHistogramCuts(page.cut, fo);
     // indptr
10 changes: 4 additions & 6 deletions src/data/gradient_index_page_source.cc
@@ -1,10 +1,9 @@
-/*!
- * Copyright 2021-2022 by XGBoost Contributors
+/**
+ * Copyright 2021-2023, XGBoost Contributors
  */
 #include "gradient_index_page_source.h"
 
-namespace xgboost {
-namespace data {
+namespace xgboost::data {
 void GradientIndexPageSource::Fetch() {
   if (!this->ReadCache()) {
     if (count_ != 0 && !sync_) {
@@ -21,5 +20,4 @@ void GradientIndexPageSource::Fetch() {
     this->WriteCache();
   }
 }
-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
73 changes: 44 additions & 29 deletions src/gbm/gbtree.cc
@@ -18,7 +18,7 @@
 #include <vector>
 
 #include "../common/common.h"
-#include "../common/error_msg.h"  // for UnknownDevice
+#include "../common/error_msg.h"  // for UnknownDevice, InplacePredictProxy
 #include "../common/random.h"
 #include "../common/threading_utils.h"
 #include "../common/timer.h"
@@ -542,6 +542,18 @@ void GBTree::PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds,
   }
 }
 
+namespace {
+inline void MismatchedDevices(Context const* booster, Context const* data) {
+  LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
+               << "is running on: " << booster->DeviceName()
+               << ", while the input data is on: " << data->DeviceName() << ".\n"
+               << R"(Potential solutions:
+- Use a data structure that matches the device ordinal in the booster.
+- Set the device for booster before call to inplace_predict.
+)";
+}
+}  // namespace
+
 void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
                           bst_layer_t layer_begin, bst_layer_t layer_end) {
   // dispatch to const function.
@@ -555,24 +567,26 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
   auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
   CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
   if (p_m->Ctx()->Device() != this->ctx_->Device()) {
-    LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
-                 << "is running on: " << this->ctx_->DeviceName()
-                 << ", while the input data is on: " << p_m->Ctx()->DeviceName() << ".";
+    MismatchedDevices(this->ctx_, p_m->Ctx());
     CHECK_EQ(out_preds->version, 0);
     auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
-    auto any_adapter = proxy->Adapter();
+    CHECK(proxy) << error::InplacePredictProxy();
     auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
     this->PredictBatchImpl(p_fmat.get(), out_preds, false, layer_begin, layer_end);
     return;
   }
 
-  if (this->ctx_->IsCPU()) {
-    this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
-  } else if (p_m->Ctx()->IsCUDA()) {
-    CHECK(this->gpu_predictor_);
-    this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
-  } else {
-    LOG(FATAL) << error::UnknownDevice();
-  }
+  bool known_type = this->ctx_->DispatchDevice(
+      [&, begin = tree_begin, end = tree_end] {
+        return this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
+      },
+      [&, begin = tree_begin, end = tree_end] {
+        return this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
+      });
+  if (!known_type) {
+    auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
+    CHECK(proxy) << error::InplacePredictProxy();
+    LOG(FATAL) << "Unknown data type for inplace prediction:" << proxy->Adapter().type().name();
+  }
 }
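A side note on the init-captures in the hunk above (my reading; the PR does not state the reason): tree_begin and tree_end come from a structured binding, and C++17 forbids lambdas from capturing structured bindings, so the new lambdas copy them through explicit init-captures such as begin = tree_begin. A minimal standalone illustration:

#include <utility>  // for pair

int main() {
  auto [lo, hi] = std::pair<int, int>{3, 7};
  // auto bad = [&] { return hi - lo; };  // ill-formed in C++17: structured
  //                                      // bindings cannot be captured
  auto ok = [lo = lo, hi = hi] { return hi - lo; };  // explicit copies are fine
  return ok() == 4 ? 0 : 1;  // exits 0
}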

@@ -808,11 +822,9 @@ class Dart : public GBTree {
     auto n_groups = model_.learner_model_param->num_output_group;
 
     if (ctx_->Device() != p_fmat->Ctx()->Device()) {
-      LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. "
-                   << "XGBoost is running on: " << this->ctx_->DeviceName()
-                   << ", while the input data is on: " << p_fmat->Ctx()->DeviceName() << ".";
+      MismatchedDevices(ctx_, p_fmat->Ctx());
       auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_fmat);
-      auto any_adapter = proxy->Adapter();
+      CHECK(proxy) << error::InplacePredictProxy();
       auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
       this->PredictBatchImpl(p_fmat.get(), p_out_preds, false, layer_begin, layer_end);
       return;
@@ -825,28 +837,31 @@ class Dart : public GBTree {
     }
     predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
 
-    auto get_predictor = [&]() -> Predictor const* {
-      if (ctx_->IsCPU()) {
-        return cpu_predictor_.get();
-      } else if (ctx_->IsCUDA()) {
-        CHECK(this->gpu_predictor_);
-        return gpu_predictor_.get();
-      } else {
-        LOG(FATAL) << error::UnknownDevice();
-        return nullptr;
-      }
-    };
     auto predict_impl = [&](size_t i) {
       predts.predictions.Fill(0);
-      bool success{get_predictor()->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1)};
+      bool success = this->ctx_->DispatchDevice(
+          [&] {
+            return cpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
+          },
+          [&] {
+            return gpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
+          });
       CHECK(success) << msg;
     };
 
     // Inplace predict is not used for training, so no need to drop tree.
     for (bst_tree_t i = tree_begin; i < tree_end; ++i) {
       predict_impl(i);
       if (i == tree_begin) {
-        get_predictor()->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_);
+        this->ctx_->DispatchDevice(
+            [&] {
+              this->cpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
+                                                       model_);
+            },
+            [&] {
+              this->gpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
+                                                       model_);
+            });
       }
       // Multiply the tree weight
       auto w = this->weight_drop_.at(i);
3 changes: 2 additions & 1 deletion src/predictor/cpu_predictor.cc
@@ -16,6 +16,7 @@
 #include "../common/bitfield.h"         // for RBitField8
 #include "../common/categorical.h"      // for IsCat, Decision
 #include "../common/common.h"           // for DivRoundUp
+#include "../common/error_msg.h"        // for InplacePredictProxy
 #include "../common/math.h"             // for CheckNAN
 #include "../common/threading_utils.h"  // for ParallelFor
 #include "../data/adapter.h"            // for ArrayAdapter, CSRAdapter, CSRArrayAdapter
@@ -741,7 +742,7 @@ class CPUPredictor : public Predictor {
                       PredictionCacheEntry *out_preds, uint32_t tree_begin,
                       unsigned tree_end) const override {
     auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
-    CHECK(proxy) << "Inplace predict accepts only DMatrixProxy as input.";
+    CHECK(proxy) << error::InplacePredictProxy();
     CHECK(!p_m->Info().IsColumnSplit())
         << "Inplace predict support for column-wise data split is not yet implemented.";
     auto x = proxy->Adapter();
5 changes: 3 additions & 2 deletions src/predictor/gpu_predictor.cu
@@ -15,8 +15,9 @@
 #include "../common/bitfield.h"
 #include "../common/categorical.h"
 #include "../common/common.h"
-#include "../common/cuda_context.cuh"
+#include "../common/cuda_context.cuh"  // for CUDAContext
 #include "../common/device_helpers.cuh"
+#include "../common/error_msg.h"  // for InplacePredictProxy
 #include "../data/device_adapter.cuh"
 #include "../data/ellpack_page.cuh"
 #include "../data/proxy_dmatrix.h"
@@ -989,7 +990,7 @@ class GPUPredictor : public xgboost::Predictor {
                       PredictionCacheEntry* out_preds, uint32_t tree_begin,
                       unsigned tree_end) const override {
     auto proxy = dynamic_cast<data::DMatrixProxy*>(p_m.get());
-    CHECK(proxy) << "Inplace predict accepts only DMatrixProxy as input.";
+    CHECK(proxy) << error::InplacePredictProxy();
     auto x = proxy->Adapter();
     if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) {
       this->DispatchedInplacePredict<data::CupyAdapter,
21 changes: 13 additions & 8 deletions tests/cpp/data/test_gradient_index.cc
@@ -27,26 +27,31 @@
 #include "xgboost/host_device_vector.h"  // for HostDeviceVector
 
 namespace xgboost::data {
-TEST(GradientIndex, ExternalMemory) {
+TEST(GradientIndex, ExternalMemoryBaseRowID) {
   Context ctx;
-  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(10000);
+  auto p_fmat = RandomDataGenerator{4096, 256, 0.5}
+                    .Device(ctx.gpu_id)
+                    .Batches(8)
+                    .GenerateSparsePageDMatrix("cache", true);
+
   std::vector<size_t> base_rowids;
-  std::vector<float> hessian(dmat->Info().num_row_, 1);
-  for (auto const &page : dmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, true})) {
+  std::vector<float> hessian(p_fmat->Info().num_row_, 1);
+  for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, true})) {
     base_rowids.push_back(page.base_rowid);
   }
-  size_t i = 0;
-  for (auto const &page : dmat->GetBatches<SparsePage>()) {
+
+  std::size_t i = 0;
+  for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
     ASSERT_EQ(base_rowids[i], page.base_rowid);
     ++i;
   }
 
   base_rowids.clear();
-  for (auto const &page : dmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, false})) {
+  for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, {64, hessian, false})) {
     base_rowids.push_back(page.base_rowid);
   }
   i = 0;
-  for (auto const &page : dmat->GetBatches<SparsePage>()) {
+  for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
    ASSERT_EQ(base_rowids[i], page.base_rowid);
     ++i;
   }
10 changes: 6 additions & 4 deletions tests/cpp/data/test_sparse_page_dmatrix.cc
@@ -76,9 +76,11 @@ TEST(SparsePageDMatrix, LoadFile) {
 // allow caller to retain pages so they can process multiple pages at the same time.
 template <typename Page>
 void TestRetainPage() {
-  auto m = CreateSparsePageDMatrix(10000);
+  std::size_t n_batches = 4;
+  auto p_fmat = RandomDataGenerator{1024, 128, 0.5f}.Batches(n_batches).GenerateSparsePageDMatrix(
+      "cache", true);
   Context ctx;
-  auto batches = m->GetBatches<Page>(&ctx);
+  auto batches = p_fmat->GetBatches<Page>(&ctx);
   auto begin = batches.begin();
   auto end = batches.end();
 
@@ -94,15 +96,15 @@ void TestRetainPage() {
     }
     ASSERT_EQ(pages.back().Size(), (*it).Size());
   }
-  ASSERT_GE(iterators.size(), 2);
+  ASSERT_GE(iterators.size(), n_batches);
 
   for (size_t i = 0; i < iterators.size(); ++i) {
     ASSERT_EQ((*iterators[i]).Size(), pages.at(i).Size());
     ASSERT_EQ((*iterators[i]).data.HostVector(), pages.at(i).data.HostVector());
   }
 
   // make sure it's const and the caller can not modify the content of page.
-  for (auto &page : m->GetBatches<Page>({&ctx})) {
+  for (auto &page : p_fmat->GetBatches<Page>({&ctx})) {
     static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value);
   }
 }