[EM] Get quantile cuts from the extmem qdm. (#10860)
trivialfis authored Sep 30, 2024
1 parent 8cf2f7a commit 92f1c48
Showing 7 changed files with 35 additions and 14 deletions.
2 changes: 1 addition & 1 deletion doc/tutorials/external_memory.rst
@@ -206,7 +206,7 @@ in-core training is one additional data read when the data is dense.
 
 To run experiments on these platforms, the open source `NVIDIA Linux driver
 <https://developer.nvidia.com/blog/nvidia-transitions-fully-towards-open-source-gpu-kernel-modules/>`__
-with version ``>=565.47`` is required.
+with version ``>=565.47`` is required; it should come with CTK 12.7 and later versions.
 
 **************
 Best Practices
11 changes: 5 additions & 6 deletions python-package/xgboost/testing/updater.py
@@ -211,6 +211,7 @@ def check_extmem_qdm(
         cache="cache",
         on_host=on_host,
     )
+
     Xy_it = xgb.ExtMemQuantileDMatrix(it)
     with pytest.raises(ValueError, match="Only the `hist`"):
         booster_it = xgb.train(
@@ -227,12 +228,10 @@ def check_extmem_qdm(
     Xy = xgb.QuantileDMatrix(it)
     booster = xgb.train({"device": device}, Xy, num_boost_round=8)
 
-    if device == "cpu":
-        # Get cuts from ellpack without CPU-GPU interpolation is not yet supported.
-        cut_it = Xy_it.get_quantile_cut()
-        cut = Xy.get_quantile_cut()
-        np.testing.assert_allclose(cut_it[0], cut[0])
-        np.testing.assert_allclose(cut_it[1], cut[1])
+    cut_it = Xy_it.get_quantile_cut()
+    cut = Xy.get_quantile_cut()
+    np.testing.assert_allclose(cut_it[0], cut[0])
+    np.testing.assert_allclose(cut_it[1], cut[1])
 
     predt_it = booster_it.predict(Xy_it)
     predt = booster.predict(Xy)
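
To make the new behavior concrete, here is a minimal sketch (not part of this commit) of reading quantile cuts back from an ExtMemQuantileDMatrix built from a batch iterator. The `Batches` iterator and the `indptr`/`values` names are illustrative assumptions; the (cut pointers, cut values) tuple layout follows the test above, and the exact `DataIter.next` return convention may differ between XGBoost versions.

import numpy as np
import xgboost as xgb


class Batches(xgb.DataIter):
    """Feed a fixed list of (X, y) batches to XGBoost, one per call."""

    def __init__(self, batches):
        self._batches = batches
        self._i = 0
        # A cache prefix is needed since pages are staged to external memory.
        super().__init__(cache_prefix="cache")

    def next(self, input_data):
        if self._i == len(self._batches):
            return False  # all batches consumed
        X, y = self._batches[self._i]
        input_data(data=X, label=y)
        self._i += 1
        return True

    def reset(self):
        self._i = 0


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(64, 4)), rng.normal(size=64)) for _ in range(2)]

Xy_it = xgb.ExtMemQuantileDMatrix(Batches(batches))
# (cut pointers, cut values): values[indptr[f]:indptr[f + 1]] are feature f's cuts.
indptr, values = Xy_it.get_quantile_cut()
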
5 changes: 3 additions & 2 deletions src/data/extmem_quantile_dmatrix.cc
@@ -58,7 +58,7 @@ ExtMemQuantileDMatrix::~ExtMemQuantileDMatrix() {
 }
 
 BatchSet<ExtSparsePage> ExtMemQuantileDMatrix::GetExtBatches(Context const *, BatchParam const &) {
-  LOG(FATAL) << "Not implemented";
+  LOG(FATAL) << "Not implemented for `ExtMemQuantileDMatrix`.";
   auto begin_iter =
       BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(nullptr));
   return BatchSet<ExtSparsePage>{begin_iter};
@@ -121,7 +121,8 @@ BatchSet<GHistIndexMatrix> ExtMemQuantileDMatrix::GetGradientIndex(Context const
     CHECK(!detail::RegenGHist(param, batch_)) << error::InconsistentMaxBin();
   }
 
-  CHECK(this->ghist_index_source_);
+  CHECK(this->ghist_index_source_)
+      << "The `ExtMemQuantileDMatrix` is initialized using GPU data, cannot be used for CPU.";
   this->ghist_index_source_->Reset(param);
 
   if (!std::isnan(param.sparse_thresh) &&
3 changes: 2 additions & 1 deletion src/data/extmem_quantile_dmatrix.cu
@@ -80,7 +80,8 @@ BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackBatches(Context const *,
 
   std::visit(
       [this, param](auto &&ptr) {
-        CHECK(ptr);
+        CHECK(ptr)
+            << "The `ExtMemQuantileDMatrix` is initialized using CPU data, cannot be used for GPU.";
         ptr->Reset(param);
       },
       this->ellpack_page_source_);
4 changes: 3 additions & 1 deletion src/data/extmem_quantile_dmatrix.h
@@ -54,7 +54,9 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
   [[nodiscard]] bool EllpackExists() const override {
     return std::visit([](auto &&v) { return static_cast<bool>(v); }, ellpack_page_source_);
   }
-  [[nodiscard]] bool GHistIndexExists() const override { return true; }
+  [[nodiscard]] bool GHistIndexExists() const override {
+    return static_cast<bool>(ghist_index_source_);
+  }
 
   [[nodiscard]] BatchSet<ExtSparsePage> GetExtBatches(Context const *ctx,
                                                       BatchParam const &param) override;
4 changes: 2 additions & 2 deletions src/data/gradient_index.cc
@@ -189,13 +189,13 @@ common::ColumnMatrix const &GHistIndexMatrix::Transpose() const {
 bst_bin_t GHistIndexMatrix::GetGindex(size_t ridx, size_t fidx) const {
   auto begin = RowIdx(ridx);
   if (IsDense()) {
-    return static_cast<bst_bin_t>(index[begin + fidx]);
+    return static_cast<bst_bin_t>(this->index[begin + fidx]);
   }
   auto end = RowIdx(ridx + 1);
   auto const& cut_ptrs = cut.Ptrs();
   auto f_begin = cut_ptrs[fidx];
   auto f_end = cut_ptrs[fidx + 1];
-  return BinarySearchBin(begin, end, index, f_begin, f_end);
+  return BinarySearchBin(begin, end, this->index, f_begin, f_end);
 }
 
 float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
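
An aside on the sparse path above: BinarySearchBin scans a row's sorted global bin indices for one that falls in the feature's bin range [f_begin, f_end) taken from the cut pointers. A rough Python analogue follows; it is an illustrative sketch, not code from this repository, and `row_bins` (one sparse row's stored bin indices) is an assumed name.

import numpy as np


def binary_search_bin(row_bins: np.ndarray, f_begin: int, f_end: int) -> int:
    """Return the global bin index stored for the feature whose bins span
    [f_begin, f_end), or -1 if this sparse row has no entry for that feature."""
    # Each feature owns a disjoint, increasing range of global bins, so a
    # row's bin indices are sorted by feature and can be binary searched.
    pos = int(np.searchsorted(row_bins, f_begin, side="left"))
    if pos < len(row_bins) and f_begin <= row_bins[pos] < f_end:
        return int(row_bins[pos])
    return -1  # the feature value is missing in this row

The dense path needs no search because a dense row stores exactly one bin per feature.
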
20 changes: 19 additions & 1 deletion tests/python-gpu/test_gpu_data_iterator.py
@@ -68,17 +68,35 @@ def test_cpu_data_iterator() -> None:
     strategies.booleans(),
 )
 @settings(deadline=None, max_examples=10, print_blob=True)
+@pytest.mark.filterwarnings("ignore")
 def test_extmem_qdm(
     n_samples_per_batch: int, n_features: int, n_batches: int, on_host: bool
 ) -> None:
     check_extmem_qdm(n_samples_per_batch, n_features, n_batches, "cuda", on_host)
 
 
+@pytest.mark.filterwarnings("ignore")
+def test_invalid_device_extmem_qdm() -> None:
+    it = tm.IteratorForTest(
+        *tm.make_batches(16, 4, 2, use_cupy=False), cache="cache", on_host=True
+    )
+    Xy = xgb.ExtMemQuantileDMatrix(it)
+    with pytest.raises(ValueError, match="cannot be used for GPU"):
+        xgb.train({"device": "cuda"}, Xy)
+
+    it = tm.IteratorForTest(
+        *tm.make_batches(16, 4, 2, use_cupy=True), cache="cache", on_host=True
+    )
+    Xy = xgb.ExtMemQuantileDMatrix(it)
+    with pytest.raises(ValueError, match="cannot be used for CPU"):
+        xgb.train({"device": "cpu"}, Xy)
+
+
 def test_concat_pages() -> None:
     it = tm.IteratorForTest(*tm.make_batches(64, 16, 4, use_cupy=True), cache=None)
     Xy = xgb.ExtMemQuantileDMatrix(it)
     with pytest.raises(ValueError, match="can not be used with concatenated pages"):
-        booster = xgb.train(
+        xgb.train(
             {
                 "device": "cuda",
                 "subsample": 0.5,
