diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 49a2594e4a52..e293ac739298 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -361,21 +361,23 @@ void SketchContainerImpl::AllReduce( } template -void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin, +bool AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin, HistogramCuts *cuts, bool secure) { size_t required_cuts = std::min(summary.size, static_cast(max_bin)); // make a copy of required_cuts for mode selection size_t required_cuts_original = required_cuts; if (secure) { - // Sync the required_cuts across all workers + // sync the required_cuts across all workers collective::Allreduce(&required_cuts, 1); } + // add the cut points auto &cut_values = cuts->cut_values_.HostVector(); - // if empty column, fill the cut values with 0 + // if secure and empty column, fill the cut values with NaN if (secure && (required_cuts_original == 0)) { for (size_t i = 1; i < required_cuts; ++i) { - cut_values.push_back(0.0); + cut_values.push_back(std::numeric_limits::quiet_NaN()); } + return true; } else { // we use the min_value as the first (0th) element, hence starting from 1. for (size_t i = 1; i < required_cuts; ++i) { @@ -384,6 +386,7 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b cut_values.push_back(cpt); } } + return false; } } @@ -437,6 +440,7 @@ void SketchContainerImpl::MakeCuts(Context const *ctx, MetaInfo const for (size_t fid = 0; fid < reduced.size(); ++fid) { size_t max_num_bins = std::min(num_cuts[fid], max_bins_); // If vertical and secure mode, we need to sync the max_num_bins aross workers + // to create the same global number of cut point bins for easier future processing if (info.IsVerticalFederated() && info.IsSecure()) { collective::Allreduce(&max_num_bins, 1); } @@ -445,15 +449,20 @@ void SketchContainerImpl::MakeCuts(Context const *ctx, MetaInfo const max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts)); } else { // use special AddCutPoint scheme for secure vertical federated learning - AddCutPoint(a, max_num_bins, p_cuts, info.IsSecure()); - // push a value that is greater than anything - const bst_float cpt = - (a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid]; - // this must be bigger than last value in a scale - const bst_float last = cpt + (fabs(cpt) + 1e-5f); - p_cuts->cut_values_.HostVector().push_back(last); + bool is_nan = AddCutPoint(a, max_num_bins, p_cuts, info.IsSecure()); + // push a value that is greater than anything if the feature is not empty + // i.e. if the last value is not NaN + if (!is_nan) { + const bst_float cpt = + (a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid]; + // this must be bigger than last value in a scale + const bst_float last = cpt + (fabs(cpt) + 1e-5f); + p_cuts->cut_values_.HostVector().push_back(last); + } else { + // if the feature is empty, push a NaN value + p_cuts->cut_values_.HostVector().push_back(std::numeric_limits::quiet_NaN()); + } } - // Ensure that every feature gets at least one quantile point CHECK_LE(p_cuts->cut_values_.HostVector().size(), std::numeric_limits::max()); auto cut_size = static_cast(p_cuts->cut_values_.HostVector().size()); @@ -461,13 +470,6 @@ void SketchContainerImpl::MakeCuts(Context const *ctx, MetaInfo const p_cuts->cut_ptrs_.HostVector().push_back(cut_size); } - if (info.IsVerticalFederated() && info.IsSecure()) { - // cut values need to be synced across all workers via Allreduce - auto cut_val = p_cuts->cut_values_.HostVector().data(); - std::size_t n = p_cuts->cut_values_.HostVector().size(); - collective::Allreduce(cut_val, n); - } - p_cuts->SetCategorical(this->has_categorical_, max_cat); monitor_.Stop(__func__); } diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index 85cd5ab67176..edd52ba22fc3 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -303,22 +303,35 @@ class HistEvaluator { // forward enumeration: split at right bound of each bin loss_chg = static_cast(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum}, - GradStats{right_sum}) - - parent.root_gain); - split_pt = cut_val[i]; // not used for partition based - best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum); + GradStats{right_sum}) - parent.root_gain); + if (!is_secure_) { + split_pt = cut_val[i]; // not used for partition based + best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum); + } else { + // secure mode: record the best split point, rather than the actual value + // since it is not accessible at this point (active party finding best-split) + best.Update(loss_chg, fidx, i, d_step == -1, false, left_sum, right_sum); + } } else { // backward enumeration: split at left bound of each bin loss_chg = static_cast(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum}, - GradStats{left_sum}) - - parent.root_gain); - if (i == imin) { - split_pt = cut.MinValues()[fidx]; + GradStats{left_sum}) - parent.root_gain); + if (!is_secure_) { + if (i == imin) { + split_pt = cut.MinValues()[fidx]; + } else { + split_pt = cut_val[i - 1]; + } + best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum); } else { - split_pt = cut_val[i - 1]; + // secure mode: record the best split point, rather than the actual value + // since it is not accessible at this point (active party finding best-split) + if (i != imin) { + i = i - 1; + } + best.Update(loss_chg, fidx, i, d_step == -1, false, right_sum, left_sum); } - best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum); } } } @@ -352,7 +365,6 @@ class HistEvaluator { } auto evaluator = tree_evaluator_.GetEvaluator(); auto const &cut_ptrs = cut.Ptrs(); - // Under secure vertical setting, only the active party is able to evaluate the split // based on global histogram. Other parties will receive the final best split information // Hence the below computation is not performed by the passive parties @@ -417,6 +429,16 @@ class HistEvaluator { all_entries[worker * entries.size() + nidx_in_set].split); } } + if (is_secure_) { + // At this point, all the workers have the best splits for all the nodes + // and workers can recover the actual split value with the split index + // Note that after the recovery, different workers will hold different + // split_value: real value for feature owner, NaN for others + for (auto & entry : entries) { + auto cut_index = entry.split.split_value; + entry.split.split_value = cut.Values()[cut_index]; + } + } } } diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index 9fa1566ea130..f7170dd14d06 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -7,6 +7,7 @@ #include "../../../src/common/hist_util.h" #include "../../../src/data/adapter.h" +#include "../../../src/data/simple_dmatrix.h" // SimpleDMatrix #include "xgboost/context.h" namespace xgboost::common { @@ -296,6 +297,105 @@ TEST(Quantile, ColSplitSorted) { TestColSplitQuantile(kRows, kCols); } +namespace { +template +void DoTestColSplitQuantileSecure() { + Context ctx; + auto const world = collective::GetWorldSize(); + auto const rank = collective::GetRank(); + size_t cols = 2; + size_t rows = 3; + + auto m = std::unique_ptr{[=]() { + std::vector data = {1, 1, 0.6, 0.4, 0.8}; + std::vector row_idx = {0, 2, 0, 1, 2}; + std::vector col_ptr = {0, 2, 5}; + data::CSCAdapter adapter(col_ptr.data(), row_idx.data(), data.data(), 2, 3); + std::unique_ptr dmat(new data::SimpleDMatrix( + &adapter, std::numeric_limits::quiet_NaN(), -1)); + EXPECT_EQ(dmat->Info().num_col_, cols); + EXPECT_EQ(dmat->Info().num_row_, rows); + EXPECT_EQ(dmat->Info().num_nonzero_, 5); + return dmat->SliceCol(world, rank); + }()}; + + std::vector column_size(cols, 0); + auto const slice_size = cols / world; + auto const slice_start = slice_size * rank; + auto const slice_end = (rank == world - 1) ? cols : slice_start + slice_size; + for (auto i = slice_start; i < slice_end; i++) { + column_size[i] = rows; + } + + auto const n_bins = 64; + + m->Info().data_split_mode = DataSplitMode::kColSecure; + // Generate cuts for distributed environment. + HistogramCuts distributed_cuts; + { + ContainerType sketch_distributed( + &ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false); + + std::vector hessian(rows, 1.0); + auto hess = Span{hessian}; + if (use_column) { + for (auto const& page : m->GetBatches(&ctx)) { + PushPage(&sketch_distributed, page, m->Info(), hess); + } + } else { + for (auto const& page : m->GetBatches(&ctx)) { + PushPage(&sketch_distributed, page, m->Info(), hess); + } + } + + sketch_distributed.MakeCuts(&ctx, m->Info(), &distributed_cuts); + } + + auto const& dptrs = distributed_cuts.Ptrs(); + auto const& dvals = distributed_cuts.Values(); + auto const& dmins = distributed_cuts.MinValues(); + std::vector expected_ptrs = {0, 1, 4}; + std::vector expected_vals = {2, 0, 0, 0}; + std::vector expected_mins = {-1e-5, 1e-5}; + if (rank == 1) { + expected_ptrs = {0, 1, 4}; + expected_vals = {0, 0.6, 0.8, 1.6}; + expected_mins = {1e-5, -1e-5}; + } + + EXPECT_EQ(dptrs.size(), expected_ptrs.size()); + for (size_t i = 0; i < expected_ptrs.size(); ++i) { + EXPECT_EQ(dptrs[i], expected_ptrs[i]) << "rank: " << rank << ", i: " << i; + } + + EXPECT_EQ(dvals.size(), expected_vals.size()); + for (size_t i = 0; i < expected_vals.size(); ++i) { + if (!std::isnan(dvals[i])) { + EXPECT_NEAR(dvals[i], expected_vals[i], 2e-2f) << "rank: " << rank << ", i: " << i; + } + } + + EXPECT_EQ(dmins.size(), expected_mins.size()); + for (size_t i = 0; i < expected_mins.size(); ++i) { + EXPECT_FLOAT_EQ(dmins[i], expected_mins[i]) << "rank: " << rank << ", i: " << i; + } +} + +template +void TestColSplitQuantileSecure() { + auto constexpr kWorkers = 2; + RunWithInMemoryCommunicator(kWorkers, DoTestColSplitQuantileSecure); +} +} // anonymous namespace + +TEST(Quantile, ColSplitSecure) { + TestColSplitQuantileSecure(); +} + +TEST(Quantile, ColSplitSecureSorted) { + TestColSplitQuantileSecure(); +} + namespace { void TestSameOnAllWorkers() { auto const world = collective::GetWorldSize(); diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index 329379b5b4d6..8eab1cd9d987 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -289,4 +289,90 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) { GradientPairPrecise{split.left_sum.GetGrad(), split.left_sum.GetHess()}, GradientPairPrecise{split.right_sum.GetGrad(), split.right_sum.GetHess()}); } + +namespace { +void DoTestEvaluateSplitsSecure(bool force_read_by_column) { + Context ctx; + auto const world = collective::GetWorldSize(); + auto const rank = collective::GetRank(); + int static constexpr kRows = 8, kCols = 16; + auto sampler = std::make_shared(1u); + + TrainParam param; + param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}}); + + auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix(); + auto m = dmat->SliceCol(world, rank); + m->Info().data_split_mode = DataSplitMode::kColSecure; + + auto evaluator = HistEvaluator{&ctx, ¶m, m->Info(), sampler}; + BoundedHistCollection hist; + std::vector row_gpairs = { + {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f}, + {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f}}; + + size_t constexpr kMaxBins = 4; + // dense, no missing values + GHistIndexMatrix gmat(&ctx, dmat.get(), kMaxBins, 0.5, false); + common::RowSetCollection row_set_collection; + std::vector &row_indices = *row_set_collection.Data(); + row_indices.resize(kRows); + std::iota(row_indices.begin(), row_indices.end(), 0); + row_set_collection.Init(); + + HistMakerTrainParam hist_param; + hist.Reset(gmat.cut.Ptrs().back(), hist_param.max_cached_hist_node); + hist.AllocateHistograms({0}); + common::BuildHist(row_gpairs, row_set_collection[0], gmat, hist[0], force_read_by_column); + + // Compute total gradient for all data points + GradientPairPrecise total_gpair; + for (const auto &e : row_gpairs) { + total_gpair += GradientPairPrecise(e); + } + + RegTree tree; + std::vector entries(1); + entries.front().nid = 0; + entries.front().depth = 0; + + evaluator.InitRoot(GradStats{total_gpair}); + evaluator.EvaluateSplits(hist, gmat.cut, {}, tree, &entries); + + auto best_loss_chg = + evaluator.Evaluator().CalcSplitGain( + param, 0, entries.front().split.SplitIndex(), + entries.front().split.left_sum, entries.front().split.right_sum) - + evaluator.Stats().front().root_gain; + ASSERT_EQ(entries.front().split.loss_chg, best_loss_chg); + ASSERT_GT(entries.front().split.loss_chg, 16.2f); + + // Assert that's the best split + for (size_t i = 1; i < gmat.cut.Ptrs().size(); ++i) { + GradStats left, right; + for (size_t j = gmat.cut.Ptrs()[i-1]; j < gmat.cut.Ptrs()[i]; ++j) { + auto loss_chg = + evaluator.Evaluator().CalcSplitGain(param, 0, i - 1, left, right) - + evaluator.Stats().front().root_gain; + ASSERT_GE(best_loss_chg, loss_chg); + left.Add(hist[0][j].GetGrad(), hist[0][j].GetHess()); + right.SetSubstract(GradStats{total_gpair}, left); + } + } + + // Free memory allocated by the DMatrix + delete m; +} + +void TestEvaluateSplitsSecure (bool force_read_by_column) { + auto constexpr kWorkers = 2; + RunWithInMemoryCommunicator(kWorkers, DoTestEvaluateSplitsSecure, force_read_by_column); +} +} // anonymous namespace + +TEST(HistEvaluator, SecureEvaluate) { + TestEvaluateSplitsSecure(false); + TestEvaluateSplitsSecure(true); +} + } // namespace xgboost::tree