Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Federated plugin for histogram. #10534

Merged
merged 10 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion plugin/federated/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ target_link_libraries(federated_client INTERFACE federated_proto)

# Rabit engine for Federated Learning.
target_sources(
objxgboost PRIVATE federated_tracker.cc federated_comm.cc federated_coll.cc federated_plugin.cc
objxgboost PRIVATE
federated_plugin.cc federated_hist.cc federated_tracker.cc federated_comm.cc federated_coll.cc
)
if(USE_CUDA)
target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)
Expand Down
10 changes: 7 additions & 3 deletions plugin/federated/federated_coll.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#include "federated_coll.h"

Expand All @@ -8,11 +8,15 @@

#include <algorithm> // for copy_n

#include "../../src/collective/allgather.h"
#include "../../src/common/common.h" // for AssertGPUSupport
#include "federated_comm.h" // for FederatedComm
#include "xgboost/collective/result.h" // for Result

#if !defined(XGBOOST_USE_CUDA)

#include "../../src/common/common.h" // for AssertGPUSupport

#endif // !defined(XGBOOST_USE_CUDA)

namespace xgboost::collective {
namespace {
[[nodiscard]] Result GetGRPCResult(std::string const &name, grpc::Status const &status) {
Expand Down
159 changes: 159 additions & 0 deletions plugin/federated/federated_hist.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/**
* Copyright 2024, XGBoost contributors
*/
#include "federated_hist.h"

#include "../../src/collective/allgather.h" // for AllgatherV
#include "../../src/collective/communicator-inl.h" // for GetRank
#include "../../src/tree/hist/histogram.h" // for SubtractHistParallel, BuildSampleHistograms

namespace xgboost::tree {
namespace {
/**
 * @brief Expand the gradient index into a dense buffer of bin indices.
 *
 * One entry is written for every (sample, feature) pair using the value returned by
 * `GHistIndexMatrix::GetGindex`. The buffer is addressed through a tensor view of shape
 * (n_samples, n_features).
 *
 * @param ctx  Context that supplies the thread count and is used to create the view.
 * @param gidx Gradient index to read from.
 *
 * @return A vector with `gidx.Size() * gidx.Features()` bin indices.
 */
auto CopyBinsToDense(Context const *ctx, GHistIndexMatrix const &gidx) {
  auto n_rows = gidx.Size();
  auto n_cols = gidx.Features();
  std::vector<bst_bin_t> dense(n_rows * n_cols);
  auto dense_view = linalg::MakeTensorView(ctx, dense, n_rows, n_cols);
  // Every row is independent, so rows are distributed across the worker threads.
  common::ParallelFor(n_rows, ctx->Threads(), [&](auto row) {
    bst_feature_t col = 0;
    while (col < n_cols) {
      dense_view(row, col) = gidx.GetGindex(row, col);
      ++col;
    }
  });
  return dense;
}
} // namespace

/**
 * @brief Build the local histogram for the requested nodes.
 *
 * - Row split: forwards to `BuildSampleHistograms`, which writes into `p_buffer`.
 * - Column split: the encryption plugin builds the histogram. On the first call the
 *   dense bin matrix and the cut pointers are handed to the plugin; afterwards only the
 *   row sets of the nodes are passed. The plugin result is kept in `hist_data_`.
 *
 * @tparam any_missing Forwarded to `BuildSampleHistograms`.
 */
template <bool any_missing>
void FederataedHistPolicy::DoBuildLocalHistograms(
    common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
    std::vector<bst_node_t> const &nodes_to_build,
    common::RowSetCollection const &row_set_collection, common::Span<GradientPair const> gpair_h,
    bool force_read_by_column, common::ParallelGHistBuilder *p_buffer) {
  if (!is_col_split_) {
    BuildSampleHistograms<any_missing>(this->ctx_->Threads(), space, gidx, nodes_to_build,
                                       row_set_collection, gpair_h, force_read_by_column,
                                       p_buffer);
    return;
  }

  // The gidx information is copied to the secure worker for encrypted histogram
  // computation. A dense copy is used because the bin compression is an XGBoost
  // internal that the plugin should not have to deal with.

  // FIXME: this can be done during reset.
  if (!is_gidx_initialized_) {
    auto bins = CopyBinsToDense(ctx_, gidx);
    auto cuts = gidx.Cuts().Ptrs();
    plugin_->Reset(cuts, bins);
    is_gidx_initialized_ = true;
  }

  // Expose the row set of every node to the plugin as (pointer, size) pairs so that the
  // row indices themselves are not copied.
  std::vector<std::uint64_t const *> ptrs;
  std::vector<std::size_t> sizes;
  ptrs.reserve(nodes_to_build.size());
  sizes.reserve(nodes_to_build.size());
  for (auto nidx : nodes_to_build) {
    auto const &rows = row_set_collection[nidx];
    ptrs.push_back(rows.begin());
    sizes.push_back(rows.Size());
  }
  std::vector<bst_node_t> nodes(nodes_to_build);
  hist_data_ = this->plugin_->BuildEncryptedHistVert(ptrs, sizes, nodes);
}

// Explicit instantiations. The member template is defined in this translation unit, so
// both values of `any_missing` are instantiated here.
template void FederataedHistPolicy::DoBuildLocalHistograms<true>(
    common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
    std::vector<bst_node_t> const &nodes_to_build,
    common::RowSetCollection const &row_set_collection, common::Span<GradientPair const> gpair_h,
    bool force_read_by_column, common::ParallelGHistBuilder *p_buffer);

template void FederataedHistPolicy::DoBuildLocalHistograms<false>(
    common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
    std::vector<bst_node_t> const &nodes_to_build,
    common::RowSetCollection const &row_set_collection, common::Span<GradientPair const> gpair_h,
    bool force_read_by_column, common::ParallelGHistBuilder *p_buffer);

namespace {
/**
 * @brief Accumulate the histograms of all workers into the local histogram collection.
 *
 * Used by the label owner, which needs to gather the result from all workers.
 * `hist_aggr` is split into `n_workers` equally sized chunks. Inside each chunk the
 * nodes are laid out in the order of `nodes_to_build`, each taking
 * `n_total_bins * kHist2F64` doubles.
 *
 * @param hist_aggr      Histograms of all workers, one chunk after another.
 * @param n_workers      Number of chunks in `hist_aggr`.
 * @param nodes_to_build Nodes whose histograms are stored in each chunk.
 * @param n_total_bins   Number of bins of a single node histogram.
 * @param p_hist         Destination; the worker histograms are added to its entries.
 */
void GatherWorkerHist(common::Span<double> hist_aggr, std::int32_t n_workers,
                      std::vector<bst_node_t> const &nodes_to_build, bst_bin_t n_total_bins,
                      tree::BoundedHistCollection *p_hist) {
  auto &hist = *p_hist;
  bst_idx_t chunk_size = hist_aggr.size() / n_workers;
  bst_node_t n_nodes = nodes_to_build.size();
  // Number of doubles occupied by the histogram of one node.
  auto node_size = n_total_bins * kHist2F64;

  for (std::int32_t worker = 0; worker < n_workers; ++worker) {
    auto chunk = hist_aggr.subspan(worker * chunk_size, chunk_size);
    for (bst_node_t k = 0; k < n_nodes; ++k) {
      auto node_values = chunk.subspan(node_size * k, node_size);
      // Reinterpret the raw doubles as gradient pairs before accumulating.
      auto hist_src_g = common::RestoreType<GradientPairPrecise>(node_values);
      auto hist_dst = hist[nodes_to_build[k]];
      CHECK_EQ(hist_src_g.size(), hist_dst.size());
      common::IncrementHist(hist_dst, hist_src_g, 0, hist_dst.size());
    }
  }
}
} // namespace

/**
 * @brief Synchronize the histograms of `nodes_to_build` across workers through the plugin.
 *
 * - Column split: the encrypted histogram stored in `hist_data_` is all-gathered and
 *   handed to the plugin. Only rank 0 writes the returned result into `p_hist`.
 * - Row split: the per-thread histograms are reduced first. The plugin then encrypts
 *   the local histogram, the encrypted buffers are all-gathered, and the aggregate
 *   returned by the plugin overwrites the local histogram.
 *
 * `nodes_to_trick` is not used by this policy.
 */
void FederataedHistPolicy::DoSyncHistogram(common::BlockedSpace2d const &space,
                                           std::vector<bst_node_t> const &nodes_to_build,
                                           std::vector<bst_node_t> const &nodes_to_trick,
                                           common::ParallelGHistBuilder *p_buffer,
                                           tree::BoundedHistCollection *p_hist) {
  auto n_total_bins = p_buffer->TotalBins();
  CHECK(!nodes_to_build.empty());

  auto &hist = *p_hist;

  if (is_col_split_) {
    // Under secure vertical mode, we perform allgather to get the global histogram. Note
    // that only the label owner (rank == 0) needs the global histogram.
    HostDeviceVector<std::int8_t> gathered;
    std::vector<std::int64_t> segments;
    auto rc = collective::AllgatherV(ctx_, linalg::MakeVec(hist_data_), &segments, &gathered);
    collective::SafeColl(rc);

    // Hand the gathered bytes to the plugin to obtain the resulting histogram.
    // Histograms from all workers are gathered to the label owner.
    auto gathered_bytes = common::RestoreType<std::uint8_t>(gathered.HostSpan());
    common::Span<double> hist_aggr = plugin_->SyncEncryptedHistVert(gathered_bytes);

    if (collective::GetRank() != 0) {
      return;
    }

    // Update the histogram for the label owner.
    std::int32_t n_workers = collective::GetWorldSize();
    CHECK_EQ(hist_aggr.size() % n_workers, 0);
    // Initialize the histogram. For the normal case, this is done by the parallel hist
    // buffer. We should try to unify the code paths.
    for (auto nidx : nodes_to_build) {
      auto node_hist = hist[nidx];
      for (auto &entry : node_hist) {
        entry = GradientPairPrecise{};
      }
    }
    GatherWorkerHist(hist_aggr, n_workers, nodes_to_build, n_total_bins, p_hist);
    return;
  }

  // Merge the histograms from each thread.
  common::ParallelFor2d(space, this->ctx_->Threads(), [&](std::size_t node, common::Range1d r) {
    p_buffer->ReduceHist(node, r.begin(), r.end());
  });

  // Encrypted mode, we need to call the plugin to perform encryption and decryption.
  // The histograms of all nodes to build are addressed as one run of doubles starting at
  // the first node.
  auto first_nidx = nodes_to_build.front();
  std::size_t n = n_total_bins * nodes_to_build.size() * kHist2F64;
  auto local_values = common::Span{reinterpret_cast<double const *>(hist[first_nidx].data()), n};
  auto encrypted = plugin_->BuildEncryptedHistHori(local_values);

  // allgather
  HostDeviceVector<std::int8_t> gathered;
  std::vector<std::int64_t> segments;
  collective::SafeColl(
      collective::AllgatherV(ctx_, linalg::MakeVec(encrypted), &segments, &gathered));

  auto gathered_bytes = common::RestoreType<std::uint8_t>(gathered.HostSpan());
  auto hist_aggr = plugin_->SyncEncryptedHistHori(gathered_bytes);
  // Assign the aggregated histogram back to the local histogram.
  auto out = reinterpret_cast<double *>(hist[first_nidx].data());
  std::copy_n(hist_aggr.data(), hist_aggr.size(), out);
}
} // namespace xgboost::tree
58 changes: 58 additions & 0 deletions plugin/federated/federated_hist.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/**
* Copyright 2024, XGBoost contributors
*/
#pragma once
#include <cstdint> // for int32_t
#include <vector> // for vector

#include "../../src/collective/comm_group.h" // for GlobalCommGroup
#include "../../src/common/hist_util.h" // for ParallelGHistBuilder
#include "../../src/common/row_set.h" // for RowSetCollection
#include "../../src/common/threading_utils.h" // for BlockedSpace2d
#include "../../src/data/gradient_index.h" // for GHistIndexMatrix
#include "../../src/tree/hist/hist_cache.h" // for BoundedHistCollection
#include "federated_comm.h" // for FederatedComm
#include "xgboost/base.h" // for GradientPair
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span
#include "xgboost/tree_model.h" // for RegTree

namespace xgboost::tree {
/**
 * @brief Federated histogram build policy.
 *
 * Histogram construction and synchronization are routed through the encryption plugin
 * obtained from the federated communicator. `DoReset` must be called before the other
 * methods, since it sets the context and the plugin handle they use.
 */
class FederataedHistPolicy {
  // fixme: duplicated code
  bool is_col_split_{false};
  bool is_distributed_{false};
  // Handle to the encryption plugin, typed as whatever `FederatedComm::EncryptionPlugin`
  // returns.
  decltype(std::declval<collective::FederatedComm>().EncryptionPlugin()) plugin_;
  // Encrypted histogram returned by the plugin in column split mode. This is a view; the
  // storage is presumably owned by the plugin -- TODO confirm.
  xgboost::common::Span<std::uint8_t> hist_data_;
  // Only initialize the aggregation context once
  bool is_gidx_initialized_{false};
  // Set by `DoReset`.
  Context const *ctx_{nullptr};

 public:
  /**
   * @brief Configure the policy and fetch the encryption plugin.
   *
   * Aborts through `CHECK` when `is_distributed` is false. The communicator of the
   * global group must be a `collective::FederatedComm`, otherwise the `dynamic_cast` to
   * a reference fails.
   *
   * @param ctx            Context used for threading and collective calls.
   * @param is_distributed Must be true; federated training is always distributed.
   * @param is_col_split   Whether the data is split by column.
   */
  void DoReset(Context const *ctx, bool is_distributed, bool is_col_split) {
    this->is_distributed_ = is_distributed;
    // Single check that carries the explanation. A bare check placed in front of it
    // would abort first and hide the message.
    CHECK(is_distributed_) << "Unreachable. Single node training can not be federated.";
    this->ctx_ = ctx;
    this->is_col_split_ = is_col_split;
    auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
    auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
    plugin_ = fed.EncryptionPlugin();
  }

  /**
   * @brief Build the local histogram of `nodes_to_build`.
   *
   * @tparam any_missing Forwarded to the sample histogram builder.
   */
  template <bool any_missing>
  void DoBuildLocalHistograms(common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
                              std::vector<bst_node_t> const &nodes_to_build,
                              common::RowSetCollection const &row_set_collection,
                              common::Span<GradientPair const> gpair_h, bool force_read_by_column,
                              common::ParallelGHistBuilder *p_buffer);

  /**
   * @brief Synchronize the histogram of `nodes_to_build` across workers via the plugin.
   */
  void DoSyncHistogram(common::BlockedSpace2d const &space,
                       std::vector<bst_node_t> const &nodes_to_build,
                       std::vector<bst_node_t> const &nodes_to_trick,
                       common::ParallelGHistBuilder *p_buffer, tree::BoundedHistCollection *p_hist);
};
} // namespace xgboost::tree
6 changes: 6 additions & 0 deletions plugin/federated/federated_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
* - Build histogram for vertical federated learning.
* - Build histogram for horizontal federated learning.
*
* Since we don't require the plugin to have network capability, the synchronization is
* performed in XGBoost. As a result, the build procedure is divided into four steps:
* first we need to build a local histogram, then encrypt it with the plugin. Afterward,
* the control returns to XGBoost, which is responsible for synchronization. Lastly, the
* plugin will receive the synchronization result and return the decrypted histogram.
*
* See below function prototypes for details. All prototypes are for C functions that are
* suitable for `dlopen`.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/data/gradient_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

#include <algorithm> // for min
#include <atomic> // for atomic
#include <cinttypes> // for uint32_t
#include <cstddef> // for size_t
#include <cstdint> // for uint32_t
#include <memory> // for make_unique
#include <vector>

Expand Down
4 changes: 2 additions & 2 deletions src/tree/hist/histogram.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023 by XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include "histogram.h"

Expand All @@ -10,7 +10,7 @@

#include "../../common/transform_iterator.h" // for MakeIndexTransformIter
#include "expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
#include "xgboost/logging.h" // for CHECK_NE
#include "xgboost/logging.h" // for CHECK_EQ
#include "xgboost/span.h" // for Span
#include "xgboost/tree_model.h" // for RegTree

Expand Down
Loading
Loading