Skip to content

Commit

Permalink
Fix external memory test.
Browse files: browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 3, 2023
1 parent bd0ece2 commit 4b0832f
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/data/gradient_index_format.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
}

std::size_t Write(GHistIndexMatrix const& page, common::AlignedFileWriteStream* fo) override {
CHECK_NE(page.index.Size(), 0) << "Empty page is not supported.";
std::size_t bytes = 0;
bytes += WriteHistogramCuts(page.cut, fo);
// indptr
Expand Down
10 changes: 4 additions & 6 deletions src/data/gradient_index_page_source.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
/*!
* Copyright 2021-2022 by XGBoost Contributors
/**
* Copyright 2021-2023, XGBoost Contributors
*/
#include "gradient_index_page_source.h"

namespace xgboost {
namespace data {
namespace xgboost::data {
void GradientIndexPageSource::Fetch() {
if (!this->ReadCache()) {
if (count_ != 0 && !sync_) {
Expand All @@ -21,5 +20,4 @@ void GradientIndexPageSource::Fetch() {
this->WriteCache();
}
}
} // namespace data
} // namespace xgboost
} // namespace xgboost::data
6 changes: 4 additions & 2 deletions tests/cpp/data/test_gradient_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
namespace xgboost::data {
TEST(GradientIndex, ExternalMemoryBaseRowID) {
Context ctx;
auto p_fmat =
RandomDataGenerator{4096, 256, 8}.Device(ctx.gpu_id).GenerateSparsePageDMatrix("cache", true);
auto p_fmat = RandomDataGenerator{4096, 256, 0.5}
.Device(ctx.gpu_id)
.Batches(8)
.GenerateSparsePageDMatrix("cache", true);

std::vector<size_t> base_rowids;
std::vector<float> hessian(p_fmat->Info().num_row_, 1);
Expand Down
3 changes: 2 additions & 1 deletion tests/cpp/gbm/test_gbtree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,8 @@ TEST(GBTree, InplacePredictionError) {

auto test_ext_err = [&](std::string booster, Context const* ctx) {
std::shared_ptr<DMatrix> p_fmat =
RandomDataGenerator{n_samples, n_features, 0.5f}.GenerateSparsePageDMatrix("cache", true);
RandomDataGenerator{n_samples, n_features, 0.5f}.Batches(2).GenerateSparsePageDMatrix(
"cache", true);
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
learner->SetParam("booster", booster);
ConfigLearnerByCtx(ctx, learner.get());
Expand Down
5 changes: 4 additions & 1 deletion tests/cpp/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,8 @@ void RandomDataGenerator::GenerateCSR(
[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateSparsePageDMatrix(
std::string prefix, bool with_label) const {
CHECK_GE(this->rows_, this->n_batches_);
CHECK_GE(this->n_batches_, 1)
<< "Must set the n_batches before generating an external memory DMatrix.";
std::unique_ptr<ArrayIterForTest> iter;
if (device_ == Context::kCpuId) {
iter = std::make_unique<NumpyArrayIterForTest>(this->sparsity_, rows_, cols_, n_batches_);
Expand All @@ -445,9 +447,10 @@ void RandomDataGenerator::GenerateCSR(
for (const auto& batch : dmat->GetBatches<xgboost::SparsePage>()) {
batch_count++;
row_count += batch.Size();
CHECK_NE(batch.data.Size(), 0);
}

EXPECT_GE(batch_count, n_batches_);
EXPECT_EQ(batch_count, n_batches_);
EXPECT_EQ(row_count, dmat->Info().num_row_);

if (with_label) {
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class RandomDataGenerator {
bst_target_t n_targets_{1};

std::int32_t device_{Context::kCpuId};
std::size_t n_batches_{1};
std::size_t n_batches_{0};
std::uint64_t seed_{0};
SimpleLCG lcg_;

Expand Down

0 comments on commit 4b0832f

Please sign in to comment.