Skip to content

Commit

Permalink
add test for missing columns
Browse files Browse the repository at this point in the history
  • Loading branch information
rongou committed Sep 21, 2023
1 parent 3915bbb commit 89fb13e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
8 changes: 3 additions & 5 deletions src/common/quantile.cu
Original file line number Diff line number Diff line change
Expand Up @@ -657,9 +657,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
size_t column_size = std::max(static_cast<size_t>(1ul), this->Column(i).size());
if (IsCat(h_feature_types, i)) {
// column_size is the number of unique values in that feature.
if (!is_column_split) {
CheckMaxCat(max_values[i].value, column_size);
}
CheckMaxCat(max_values[i].value, column_size);
h_out_columns_ptr.push_back(max_values[i].value + 1); // includes both max_cat and 0.
} else {
h_out_columns_ptr.push_back(
Expand All @@ -683,14 +681,14 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
d_out_columns_ptr[column_id + 1] -
d_out_columns_ptr[column_id]);
idx -= d_out_columns_ptr[column_id];
if (in_column.empty()) {
if (in_column.size() == 0) {
// If the column is empty, we push a dummy value. It won't affect training as the
// column is empty, trees cannot split on it. This is just to be consistent with
// rest of the library.
if (idx == 0) {
d_min_values[column_id] = kRtEps;
out_column[0] = kRtEps;
assert(is_column_split || out_column.size() == 1);
assert(out_column.size() == 1);
}
return;
}
Expand Down
25 changes: 25 additions & 0 deletions tests/cpp/common/test_quantile.cu
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,31 @@ TEST(GPUQuantile, MultiMerge) {
});
}

// Runs the GPU sketch on a column slice of a randomly generated matrix whose
// features alternate between numerical (even index) and categorical (odd index),
// then checks that the resulting cuts still report categorical features.
TEST(GPUQuantile, MissingColumns) {
  // The full matrix lives only inside this helper; just the slice is returned.
  auto make_sliced = [] {
    std::size_t constexpr kRows = 1000, kCols = 100;
    float constexpr kSparsity = 0.5f;

    // Start with every feature numerical, then flip the odd-indexed ones.
    std::vector<FeatureType> feature_types(kCols, FeatureType::kNumerical);
    for (std::size_t fidx = 1; fidx < feature_types.size(); fidx += 2) {
      feature_types[fidx] = FeatureType::kCategorical;
    }

    RandomDataGenerator generator{kRows, kCols, kSparsity};
    auto full = generator.Seed(0)
                    .Lower(.0f)
                    .Upper(1.0f)
                    .Type(feature_types)
                    .MaxCategory(13)
                    .GenerateDMatrix();
    // NOTE(review): presumably this selects slice index 1 out of 2 column slices -- confirm
    // against DMatrix::SliceCol.
    return std::unique_ptr<DMatrix>{full->SliceCol(2, 1)};
  };

  std::unique_ptr<DMatrix> sliced = make_sliced();
  // The data split mode is explicitly set to row-wise before sketching.
  sliced->Info().data_split_mode = DataSplitMode::kRow;

  std::size_t constexpr kBins = 64;
  auto ctx = MakeCUDACtx(0);
  HistogramCuts sketch_cuts = common::DeviceSketch(&ctx, sliced.get(), kBins);
  ASSERT_TRUE(sketch_cuts.HasCategorical());
}

namespace {
void TestAllReduceBasic() {
auto const world = collective::GetWorldSize();
Expand Down

0 comments on commit 89fb13e

Please sign in to comment.