Skip to content

Commit

Permalink
Implement secure boost scheme - secure evaluation and validation (during training) without local feature leakage (#10079)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 authored May 16, 2024
1 parent 9ecfc84 commit 8585df5
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 30 deletions.
40 changes: 21 additions & 19 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,21 +361,23 @@ void SketchContainerImpl<WQSketch>::AllReduce(
}

template <typename SketchType>
void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
bool AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
HistogramCuts *cuts, bool secure) {
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
// make a copy of required_cuts for mode selection
size_t required_cuts_original = required_cuts;
if (secure) {
// Sync the required_cuts across all workers
// sync the required_cuts across all workers
collective::Allreduce<collective::Operation::kMax>(&required_cuts, 1);
}
// add the cut points
auto &cut_values = cuts->cut_values_.HostVector();
// if empty column, fill the cut values with 0
// if secure and empty column, fill the cut values with NaN
if (secure && (required_cuts_original == 0)) {
for (size_t i = 1; i < required_cuts; ++i) {
cut_values.push_back(0.0);
cut_values.push_back(std::numeric_limits<double>::quiet_NaN());
}
return true;
} else {
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
Expand All @@ -384,6 +386,7 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b
cut_values.push_back(cpt);
}
}
return false;
}
}

Expand Down Expand Up @@ -437,6 +440,7 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
for (size_t fid = 0; fid < reduced.size(); ++fid) {
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
// If vertical and secure mode, we need to sync the max_num_bins aross workers
// to create the same global number of cut point bins for easier future processing
if (info.IsVerticalFederated() && info.IsSecure()) {
collective::Allreduce<collective::Operation::kMax>(&max_num_bins, 1);
}
Expand All @@ -445,29 +449,27 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts));
} else {
// use special AddCutPoint scheme for secure vertical federated learning
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts, info.IsSecure());
// push a value that is greater than anything
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
bool is_nan = AddCutPoint<WQSketch>(a, max_num_bins, p_cuts, info.IsSecure());
// push a value that is greater than anything if the feature is not empty
// i.e. if the last value is not NaN
if (!is_nan) {
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
} else {
// if the feature is empty, push a NaN value
p_cuts->cut_values_.HostVector().push_back(std::numeric_limits<double>::quiet_NaN());
}
}

// Ensure that every feature gets at least one quantile point
CHECK_LE(p_cuts->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
auto cut_size = static_cast<uint32_t>(p_cuts->cut_values_.HostVector().size());
CHECK_GT(cut_size, p_cuts->cut_ptrs_.HostVector().back());
p_cuts->cut_ptrs_.HostVector().push_back(cut_size);
}

if (info.IsVerticalFederated() && info.IsSecure()) {
// cut values need to be synced across all workers via Allreduce
auto cut_val = p_cuts->cut_values_.HostVector().data();
std::size_t n = p_cuts->cut_values_.HostVector().size();
collective::Allreduce<collective::Operation::kSum>(cut_val, n);
}

p_cuts->SetCategorical(this->has_categorical_, max_cat);
monitor_.Stop(__func__);
}
Expand Down
44 changes: 33 additions & 11 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,22 +303,35 @@ class HistEvaluator {
// forward enumeration: split at right bound of each bin
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
GradStats{right_sum}) -
parent.root_gain);
split_pt = cut_val[i]; // not used for partition based
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
GradStats{right_sum}) - parent.root_gain);
if (!is_secure_) {
split_pt = cut_val[i]; // not used for partition based
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
} else {
// secure mode: record the best split point, rather than the actual value
// since it is not accessible at this point (active party finding best-split)
best.Update(loss_chg, fidx, i, d_step == -1, false, left_sum, right_sum);
}
} else {
// backward enumeration: split at left bound of each bin
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
GradStats{left_sum}) -
parent.root_gain);
if (i == imin) {
split_pt = cut.MinValues()[fidx];
GradStats{left_sum}) - parent.root_gain);
if (!is_secure_) {
if (i == imin) {
split_pt = cut.MinValues()[fidx];
} else {
split_pt = cut_val[i - 1];
}
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
} else {
split_pt = cut_val[i - 1];
// secure mode: record the best split point, rather than the actual value
// since it is not accessible at this point (active party finding best-split)
if (i != imin) {
i = i - 1;
}
best.Update(loss_chg, fidx, i, d_step == -1, false, right_sum, left_sum);
}
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
}
}
}
Expand Down Expand Up @@ -352,7 +365,6 @@ class HistEvaluator {
}
auto evaluator = tree_evaluator_.GetEvaluator();
auto const &cut_ptrs = cut.Ptrs();

// Under secure vertical setting, only the active party is able to evaluate the split
// based on global histogram. Other parties will receive the final best split information
// Hence the below computation is not performed by the passive parties
Expand Down Expand Up @@ -417,6 +429,16 @@ class HistEvaluator {
all_entries[worker * entries.size() + nidx_in_set].split);
}
}
if (is_secure_) {
// At this point, all the workers have the best splits for all the nodes
// and workers can recover the actual split value with the split index
// Note that after the recovery, different workers will hold different
// split_value: real value for feature owner, NaN for others
for (auto & entry : entries) {
auto cut_index = entry.split.split_value;
entry.split.split_value = cut.Values()[cut_index];
}
}
}
}

Expand Down
100 changes: 100 additions & 0 deletions tests/cpp/common/test_quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "../../../src/common/hist_util.h"
#include "../../../src/data/adapter.h"
#include "../../../src/data/simple_dmatrix.h" // SimpleDMatrix
#include "xgboost/context.h"

namespace xgboost::common {
Expand Down Expand Up @@ -296,6 +297,105 @@ TEST(Quantile, ColSplitSorted) {
TestColSplitQuantile<true>(kRows, kCols);
}

namespace {
template <bool use_column>
void DoTestColSplitQuantileSecure() {
  // Verify quantile-cut construction under the secure column-split (vertical
  // federated) mode: each worker owns a slice of the columns, and the cut
  // values for columns it does not own are filled with NaN so that no local
  // feature information leaks to other parties (see the isnan skip below).
  Context ctx;
  auto const world = collective::GetWorldSize();
  auto const rank = collective::GetRank();
  size_t cols = 2;
  size_t rows = 3;

  // Build a small 3x2 sparse matrix in CSC form, then keep only the column
  // slice belonging to this rank.
  auto m = std::unique_ptr<DMatrix>{[=]() {
    std::vector<float> data = {1, 1, 0.6, 0.4, 0.8};
    std::vector<unsigned> row_idx = {0, 2, 0, 1, 2};
    std::vector<size_t> col_ptr = {0, 2, 5};
    data::CSCAdapter adapter(col_ptr.data(), row_idx.data(), data.data(), 2, 3);
    std::unique_ptr<data::SimpleDMatrix> dmat(new data::SimpleDMatrix(
        &adapter, std::numeric_limits<float>::quiet_NaN(), -1));
    EXPECT_EQ(dmat->Info().num_col_, cols);
    EXPECT_EQ(dmat->Info().num_row_, rows);
    EXPECT_EQ(dmat->Info().num_nonzero_, 5);
    return dmat->SliceCol(world, rank);
  }()};

  // Only the columns owned by this rank report a non-zero size; the last rank
  // absorbs any remainder columns.
  std::vector<bst_row_t> column_size(cols, 0);
  auto const slice_size = cols / world;
  auto const slice_start = slice_size * rank;
  auto const slice_end = (rank == world - 1) ? cols : slice_start + slice_size;
  for (auto i = slice_start; i < slice_end; i++) {
    column_size[i] = rows;
  }

  auto const n_bins = 64;

  m->Info().data_split_mode = DataSplitMode::kColSecure;
  // Generate cuts for distributed environment.
  HistogramCuts distributed_cuts;
  {
    ContainerType<use_column> sketch_distributed(
        &ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);

    // Unit hessian: an unweighted quantile sketch.
    std::vector<float> hessian(rows, 1.0);
    auto hess = Span<float const>{hessian};
    // use_column selects the SortedCSCPage (column-wise) push path; otherwise
    // the row-wise SparsePage path is exercised.
    if (use_column) {
      for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
        PushPage(&sketch_distributed, page, m->Info(), hess);
      }
    } else {
      for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
        PushPage(&sketch_distributed, page, m->Info(), hess);
      }
    }

    sketch_distributed.MakeCuts(&ctx, m->Info(), &distributed_cuts);
  }

  // Expected results are rank-specific: a worker holds real cut values only
  // for its own feature; slots belonging to the other party are NaN.
  auto const& dptrs = distributed_cuts.Ptrs();
  auto const& dvals = distributed_cuts.Values();
  auto const& dmins = distributed_cuts.MinValues();
  // NOTE(review): expected_ptrs stores integral offsets but is declared as
  // std::vector<float>; EXPECT_EQ still compares correctly, though an integer
  // element type would express the intent better.
  std::vector<float> expected_ptrs = {0, 1, 4};
  std::vector<float> expected_vals = {2, 0, 0, 0};
  std::vector<float> expected_mins = {-1e-5, 1e-5};
  if (rank == 1) {
    expected_ptrs = {0, 1, 4};
    expected_vals = {0, 0.6, 0.8, 1.6};
    expected_mins = {1e-5, -1e-5};
  }

  EXPECT_EQ(dptrs.size(), expected_ptrs.size());
  for (size_t i = 0; i < expected_ptrs.size(); ++i) {
    EXPECT_EQ(dptrs[i], expected_ptrs[i]) << "rank: " << rank << ", i: " << i;
  }

  EXPECT_EQ(dvals.size(), expected_vals.size());
  for (size_t i = 0; i < expected_vals.size(); ++i) {
    // NaN marks a cut value this worker must not see; compare only real ones.
    if (!std::isnan(dvals[i])) {
      EXPECT_NEAR(dvals[i], expected_vals[i], 2e-2f) << "rank: " << rank << ", i: " << i;
    }
  }

  EXPECT_EQ(dmins.size(), expected_mins.size());
  for (size_t i = 0; i < expected_mins.size(); ++i) {
    EXPECT_FLOAT_EQ(dmins[i], expected_mins[i]) << "rank: " << rank << ", i: " << i;
  }
}

template <bool use_column>
void TestColSplitQuantileSecure() {
  // Drive the secure column-split quantile test on an in-memory communicator
  // spanning two workers (one column slice each).
  constexpr int kWorkerCount = 2;
  RunWithInMemoryCommunicator(kWorkerCount, DoTestColSplitQuantileSecure<use_column>);
}
} // anonymous namespace

// Secure vertical quantile building via the row-wise SparsePage path.
TEST(Quantile, ColSplitSecure) {
  TestColSplitQuantileSecure<false>();
}

// Secure vertical quantile building via the column-wise SortedCSCPage path.
TEST(Quantile, ColSplitSecureSorted) {
  TestColSplitQuantileSecure<true>();
}

namespace {
void TestSameOnAllWorkers() {
auto const world = collective::GetWorldSize();
Expand Down
86 changes: 86 additions & 0 deletions tests/cpp/tree/hist/test_evaluate_splits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,4 +289,90 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
GradientPairPrecise{split.left_sum.GetGrad(), split.left_sum.GetHess()},
GradientPairPrecise{split.right_sum.GetGrad(), split.right_sum.GetHess()});
}

namespace {
void DoTestEvaluateSplitsSecure(bool force_read_by_column) {
  // Verify split evaluation under the secure column-split mode: build a root
  // histogram, run HistEvaluator::EvaluateSplits, and check the reported best
  // split's gain is consistent and beats every candidate split enumerated by
  // hand below.
  Context ctx;
  auto const world = collective::GetWorldSize();
  auto const rank = collective::GetRank();
  int static constexpr kRows = 8, kCols = 16;
  auto sampler = std::make_shared<common::ColumnSampler>(1u);

  TrainParam param;
  // Disable min_child_weight / lambda so gains compare purely on gradients.
  param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});

  auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();
  // NOTE(review): raw owning pointer freed with `delete m;` at the end; a
  // std::unique_ptr<DMatrix> (as used in the quantile tests) would be
  // leak-safe if an assertion fires — TODO confirm SliceCol's return type.
  auto m = dmat->SliceCol(world, rank);
  m->Info().data_split_mode = DataSplitMode::kColSecure;

  auto evaluator = HistEvaluator{&ctx, &param, m->Info(), sampler};
  BoundedHistCollection hist;
  // Fixed gradient/hessian pairs, one per row.
  std::vector<GradientPair> row_gpairs = {
      {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
      {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f}};

  size_t constexpr kMaxBins = 4;
  // dense, no missing values
  GHistIndexMatrix gmat(&ctx, dmat.get(), kMaxBins, 0.5, false);
  // All rows belong to the root node (node 0).
  common::RowSetCollection row_set_collection;
  std::vector<size_t> &row_indices = *row_set_collection.Data();
  row_indices.resize(kRows);
  std::iota(row_indices.begin(), row_indices.end(), 0);
  row_set_collection.Init();

  HistMakerTrainParam hist_param;
  hist.Reset(gmat.cut.Ptrs().back(), hist_param.max_cached_hist_node);
  hist.AllocateHistograms({0});
  common::BuildHist<false>(row_gpairs, row_set_collection[0], gmat, hist[0], force_read_by_column);

  // Compute total gradient for all data points
  GradientPairPrecise total_gpair;
  for (const auto &e : row_gpairs) {
    total_gpair += GradientPairPrecise(e);
  }

  RegTree tree;
  std::vector<CPUExpandEntry> entries(1);
  entries.front().nid = 0;
  entries.front().depth = 0;

  evaluator.InitRoot(GradStats{total_gpair});
  evaluator.EvaluateSplits(hist, gmat.cut, {}, tree, &entries);

  // Recompute the gain of the reported best split and check it matches the
  // loss_chg stored in the entry.
  auto best_loss_chg =
      evaluator.Evaluator().CalcSplitGain(
          param, 0, entries.front().split.SplitIndex(),
          entries.front().split.left_sum, entries.front().split.right_sum) -
      evaluator.Stats().front().root_gain;
  ASSERT_EQ(entries.front().split.loss_chg, best_loss_chg);
  ASSERT_GT(entries.front().split.loss_chg, 16.2f);

  // Assert that's the best split
  for (size_t i = 1; i < gmat.cut.Ptrs().size(); ++i) {
    GradStats left, right;
    for (size_t j = gmat.cut.Ptrs()[i-1]; j < gmat.cut.Ptrs()[i]; ++j) {
      // NOTE(review): the gain is evaluated before bin j's stats are
      // accumulated, so the first iteration compares against empty left/right
      // stats — presumably an intentional degenerate case; confirm.
      auto loss_chg =
          evaluator.Evaluator().CalcSplitGain(param, 0, i - 1, left, right) -
          evaluator.Stats().front().root_gain;
      ASSERT_GE(best_loss_chg, loss_chg);
      left.Add(hist[0][j].GetGrad(), hist[0][j].GetHess());
      right.SetSubstract(GradStats{total_gpair}, left);
    }
  }

  // Free memory allocated by the DMatrix
  delete m;
}

// Drive the secure split-evaluation test on an in-memory communicator
// spanning two workers.
void TestEvaluateSplitsSecure(bool force_read_by_column) {
  constexpr int kWorkerCount = 2;
  RunWithInMemoryCommunicator(kWorkerCount, DoTestEvaluateSplitsSecure, force_read_by_column);
}
} // anonymous namespace

// Exercise both histogram read paths: row-wise (false) and forced
// column-wise (true).
TEST(HistEvaluator, SecureEvaluate) {
  TestEvaluateSplitsSecure(false);
  TestEvaluateSplitsSecure(true);
}

} // namespace xgboost::tree

0 comments on commit 8585df5

Please sign in to comment.