Skip to content

Commit

Permalink
[fed] Add federated plugin.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 15, 2024
1 parent 11563ca commit 296a941
Show file tree
Hide file tree
Showing 33 changed files with 1,236 additions and 259 deletions.
14 changes: 4 additions & 10 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ enum class DataType : uint8_t {

enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };

enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2 };
enum class DataSplitMode : int { kRow = 0, kCol = 1 };

/*!
* \brief Meta information about dataset, always sit in memory.
Expand Down Expand Up @@ -174,17 +174,11 @@ class MetaInfo {
*/
void SynchronizeNumberOfColumns(Context const* ctx);

/*! \brief Whether the data is split row-wise. */
bool IsRowSplit() const {
return data_split_mode == DataSplitMode::kRow;
}
/** @brief Whether the data is split row-wise. */
[[nodiscard]] bool IsRowSplit() const { return data_split_mode == DataSplitMode::kRow; }

/** @brief Whether the data is split column-wise. */
bool IsColumnSplit() const { return (data_split_mode == DataSplitMode::kCol)
|| (data_split_mode == DataSplitMode::kColSecure); }

/** @brief Whether the data is split column-wise with secure computation. */
bool IsSecure() const { return data_split_mode == DataSplitMode::kColSecure; }
[[nodiscard]] bool IsColumnSplit() const { return !this->IsRowSplit(); }

/** @brief Whether this is a learning to rank data. */
bool IsRanking() const { return !group_ptr_.empty(); }
Expand Down
5 changes: 5 additions & 0 deletions include/xgboost/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,11 @@ auto MakeVec(T *ptr, size_t s, DeviceOrd device = DeviceOrd::CPU()) {
return linalg::TensorView<T, 1>{{ptr, s}, {s}, device};
}

template <typename T>
auto MakeVec(common::Span<T> data, DeviceOrd device = DeviceOrd::CPU()) {
return linalg::TensorView<T, 1>{data, {data.size()}, device};
}

template <typename T>
auto MakeVec(HostDeviceVector<T> *data) {
return MakeVec(data->Device().IsCUDA() ? data->DevicePointer() : data->HostPointer(),
Expand Down
3 changes: 2 additions & 1 deletion plugin/federated/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ target_link_libraries(federated_client INTERFACE federated_proto)

# Rabit engine for Federated Learning.
target_sources(
objxgboost PRIVATE federated_tracker.cc federated_comm.cc federated_coll.cc
objxgboost PRIVATE
federated_plugin.cc federated_hist.cc federated_tracker.cc federated_comm.cc federated_coll.cc
)
if(USE_CUDA)
target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)
Expand Down
10 changes: 7 additions & 3 deletions plugin/federated/federated_coll.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*/
#include "federated_coll.h"

Expand All @@ -8,11 +8,15 @@

#include <algorithm> // for copy_n

#include "../../src/collective/allgather.h"
#include "../../src/common/common.h" // for AssertGPUSupport
#include "federated_comm.h" // for FederatedComm
#include "xgboost/collective/result.h" // for Result

#if !defined(XGBOOST_USE_CUDA)

#include "../../src/common/common.h" // for AssertGPUSupport

#endif // !defined(XGBOOST_USE_CUDA)

namespace xgboost::collective {
namespace {
[[nodiscard]] Result GetGRPCResult(std::string const &name, grpc::Status const &status) {
Expand Down
10 changes: 9 additions & 1 deletion plugin/federated/federated_comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <cstdint> // for int32_t
#include <cstdlib> // for getenv
#include <limits> // for numeric_limits
#include <memory> // for make_shared
#include <string> // for string, stoi

#include "../../src/common/common.h" // for Split
Expand All @@ -32,7 +33,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
CHECK_LT(rank, world) << "Invalid worker rank.";

auto certs = {server_cert, client_cert, client_cert};
auto is_empty = [](auto const& s) { return s.empty(); };
auto is_empty = [](auto const& s) {
return s.empty();
};
bool valid = std::all_of(certs.begin(), certs.end(), is_empty) ||
std::none_of(certs.begin(), certs.end(), is_empty);
CHECK(valid) << "Invalid arguments for certificates.";
Expand Down Expand Up @@ -123,6 +126,11 @@ FederatedComm::FederatedComm(std::int32_t retry, std::chrono::seconds timeout, s
client_key = OptionalArg<String>(config, "federated_client_key_path", client_key);
client_cert = OptionalArg<String>(config, "federated_client_cert_path", client_cert);

/**
* Hist encryption plugin.
*/
this->plugin_.reset(CreateFederatedPlugin(config));

this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
client_cert);
}
Expand Down
8 changes: 7 additions & 1 deletion plugin/federated/federated_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
#include <memory> // for shared_ptr
#include <string> // for string

#include "../../src/collective/comm.h" // for HostComm
#include "../../src/collective/comm.h" // for HostComm
#include "federated_plugin.h" // for FederatedPlugin
#include "xgboost/json.h"

namespace xgboost::collective {
class FederatedComm : public HostComm {
std::shared_ptr<federated::Federated::Stub> stub_;
// Plugin for encryption
std::shared_ptr<FederatedPluginBase> plugin_{nullptr};

void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
std::string const& server_cert, std::string const& client_key,
Expand Down Expand Up @@ -62,6 +65,7 @@ class FederatedComm : public HostComm {
return Success();
}
[[nodiscard]] bool IsFederated() const override { return true; }
[[nodiscard]] bool IsEncrypted() const override { return static_cast<bool>(plugin_); }
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }

[[nodiscard]] Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
Expand All @@ -73,5 +77,7 @@ class FederatedComm : public HostComm {
*out = "rank:" + std::to_string(rank);
return Success();
};

auto EncryptionPlugin() const { return plugin_; }
};
} // namespace xgboost::collective
149 changes: 149 additions & 0 deletions plugin/federated/federated_hist.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/**
* Copyright 2024, XGBoost contributors
*/
#include "federated_hist.h"

#include "../../src/collective/allgather.h" // for AllgatherV
#include "../../src/collective/communicator-inl.h" // for GetRank
#include "../../src/tree/hist/histogram.h" // for SubtractHistParallel, BuildSampleHistograms

namespace xgboost::tree {
template <bool any_missing>
void FederataedHistPolicy::DoBuildLocalHistograms(
common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
std::vector<bst_node_t> const &nodes_to_build,
common::RowSetCollection const &row_set_collection, common::Span<GradientPair const> gpair_h,
bool force_read_by_column, common::ParallelGHistBuilder *buffer) {
if (is_col_split_) {
// Call the interface to transmit gidx information to the secure worker for encrypted
// histogram computation
auto cuts = gidx.Cuts().Ptrs();
// fixme: this can be done during reset.
if (!is_aggr_context_initialized_) {
auto slots = std::vector<int>();
auto num_rows = gidx.Size();
for (std::size_t row = 0; row < num_rows; row++) {
for (std::size_t f = 0; f < cuts.size() - 1; f++) {
auto slot = gidx.GetGindex(row, f);
slots.push_back(slot);
}
}
plugin_->Reset(cuts, slots);
is_aggr_context_initialized_ = true;
}

// Further use the row set collection info to
// get the encrypted histogram from the secure worker
std::vector<std::uint64_t const *> ptrs(nodes_to_build.size());
std::vector<std::size_t> sizes(nodes_to_build.size());
std::vector<bst_node_t> nodes(nodes_to_build.size());
for (std::size_t i = 0; i < nodes_to_build.size(); ++i) {
auto nidx = nodes_to_build[i];
ptrs[i] = row_set_collection[nidx].begin();
sizes[i] = row_set_collection[nidx].Size();
nodes[i] = nidx;
}
hist_data_ = this->plugin_->BuildEncryptedHistVert(ptrs, sizes, nodes);
} else {
BuildSampleHistograms<any_missing>(this->n_threads_, space, gidx, nodes_to_build,
row_set_collection, gpair_h, force_read_by_column, buffer);
}
}

template void FederataedHistPolicy::DoBuildLocalHistograms<true>(
common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
std::vector<bst_node_t> const &nodes_to_build,
common::RowSetCollection const &row_set_collection, common::Span<GradientPair const> gpair_h,
bool force_read_by_column, common::ParallelGHistBuilder *buffer);
template void FederataedHistPolicy::DoBuildLocalHistograms<false>(
common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
std::vector<bst_node_t> const &nodes_to_build,
common::RowSetCollection const &row_set_collection, common::Span<GradientPair const> gpair_h,
bool force_read_by_column, common::ParallelGHistBuilder *buffer);

void FederataedHistPolicy::DoSyncHistogram(Context const *ctx, RegTree const *p_tree,
std::vector<bst_node_t> const &nodes_to_build,
std::vector<bst_node_t> const &nodes_to_trick,
common::ParallelGHistBuilder *buffer,
tree::BoundedHistCollection *p_hist) {
auto n_total_bins = buffer->TotalBins();
common::BlockedSpace2d space(
nodes_to_build.size(), [&](std::size_t) { return n_total_bins; }, 1024);
CHECK(!nodes_to_build.empty());

auto &hist = *p_hist;
if (is_col_split_) {
// Under secure vertical mode, we perform allgather to get the global histogram. Note
// that only the label owner (rank == 0) needs the global histogram

// Perform AllGather
HostDeviceVector<std::int8_t> hist_entries;
std::vector<std::int64_t> recv_segments;
collective::SafeColl(
collective::AllgatherV(ctx, linalg::MakeVec(hist_data_), &recv_segments, &hist_entries));

// Call interface here to post-process the messages
common::Span<double> hist_aggr =
plugin_->SyncEncryptedHistVert(common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

// Update histogram for label owner
if (collective::GetRank() == 0) {
// iterator of the beginning of the vector
bst_node_t n_nodes = nodes_to_build.size();
std::int32_t n_workers = collective::GetWorldSize();
bst_idx_t worker_size = hist_aggr.size() / n_workers;
CHECK_EQ(hist_aggr.size() % n_workers, 0);
// Initialize histogram. For the normal case, this is done by the parallel hist
// buffer. We should try to unify the code paths.
for (auto nidx : nodes_to_build) {
auto hist_dst = hist[nidx];
std::fill_n(hist_dst.data(), hist_dst.size(), GradientPairPrecise{});
}

// for each worker
for (auto widx = 0; widx < n_workers; ++widx) {
auto worker_hist = hist_aggr.subspan(widx * worker_size, worker_size);
// for each node
for (bst_node_t nidx_in_set = 0; nidx_in_set < n_nodes; ++nidx_in_set) {
auto hist_src = worker_hist.subspan(n_total_bins * 2 * nidx_in_set, n_total_bins * 2);
auto hist_src_g = common::RestoreType<GradientPairPrecise>(hist_src);
auto hist_dst = hist[nodes_to_build[nidx_in_set]];
CHECK_EQ(hist_src_g.size(), hist_dst.size());
common::IncrementHist(hist_dst, hist_src_g, 0, hist_dst.size());
}
}
}
} else {
common::ParallelFor2d(space, this->n_threads_, [&](std::size_t node, common::Range1d r) {
// Merging histograms from each thread.
buffer->ReduceHist(node, r.begin(), r.end());
});
// Secure mode, we need to call interface to perform encryption and decryption
// note that the actual aggregation will be performed at server side
auto first_nidx = nodes_to_build.front();
std::size_t n = n_total_bins * nodes_to_build.size() * 2;
auto hist_to_aggr = std::vector<double>();
for (std::size_t hist_idx = 0; hist_idx < n; hist_idx++) {
double hist_item = reinterpret_cast<double *>(hist[first_nidx].data())[hist_idx];
hist_to_aggr.push_back(hist_item);
}
// ProcessHistograms
auto hist_buf = plugin_->BuildEncryptedHistHori(hist_to_aggr);

// allgather
HostDeviceVector<std::int8_t> hist_entries;
std::vector<std::int64_t> recv_segments;
auto rc = collective::AllgatherV(ctx, linalg::MakeVec(hist_buf), &recv_segments, &hist_entries);
collective::SafeColl(rc);

auto hist_aggr =
plugin_->SyncEncryptedHistHori(common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));
// Assign the aggregated histogram back to the local histogram
for (std::size_t hist_idx = 0; hist_idx < n; hist_idx++) {
reinterpret_cast<double *>(hist[first_nidx].data())[hist_idx] = hist_aggr[hist_idx];
}
}

SubtractHistParallel(ctx, space, p_tree, nodes_to_build, nodes_to_trick, buffer, p_hist);
}
} // namespace xgboost::tree
58 changes: 58 additions & 0 deletions plugin/federated/federated_hist.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/**
* Copyright 2024, XGBoost contributors
*/
#pragma once
#include <cstdint> // for int32_t
#include <vector> // for vector

#include "../../src/collective/comm_group.h" // for GlobalCommGroup
#include "../../src/common/hist_util.h" // for ParallelGHistBuilder
#include "../../src/common/row_set.h" // for RowSetCollection
#include "../../src/common/threading_utils.h" // for BlockedSpace2d
#include "../../src/data/gradient_index.h" // for GHistIndexMatrix
#include "../../src/tree/hist/hist_cache.h" // for BoundedHistCollection
#include "federated_comm.h" // for FederatedComm
#include "xgboost/base.h" // for GradientPair
#include "xgboost/context.h" // for Context
#include "xgboost/span.h" // for Span
#include "xgboost/tree_model.h" // for RegTree

namespace xgboost::tree {
/**
* @brief Federated histogram build policy
*/
class FederataedHistPolicy {
// fixme: duplicated code
bool is_col_split_{false};
bool is_distributed_{false};
std::int32_t n_threads_{false};
decltype(std::declval<collective::FederatedComm>().EncryptionPlugin()) plugin_;
xgboost::common::Span<std::uint8_t> hist_data_;
// only initialize the aggregation context once
bool is_aggr_context_initialized_ = false; // fixme

public:
void Reset(Context const *ctx, bool is_distributed, bool is_col_split) {
this->is_distributed_ = is_distributed;
CHECK(is_distributed);
this->n_threads_ = ctx->Threads();
this->is_col_split_ = is_col_split;
auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
plugin_ = fed.EncryptionPlugin();
CHECK(is_distributed_) << "Unreachable. Single node training can not be federated.";
}

template <bool any_missing>
void DoBuildLocalHistograms(common::BlockedSpace2d const &space, GHistIndexMatrix const &gidx,
std::vector<bst_node_t> const &nodes_to_build,
common::RowSetCollection const &row_set_collection,
common::Span<GradientPair const> gpair_h, bool force_read_by_column,
common::ParallelGHistBuilder *buffer);

void DoSyncHistogram(Context const *ctx, RegTree const *p_tree,
std::vector<bst_node_t> const &nodes_to_build,
std::vector<bst_node_t> const &nodes_to_trick,
common::ParallelGHistBuilder *buffer, tree::BoundedHistCollection *p_hist);
};
} // namespace xgboost::tree
Loading

0 comments on commit 296a941

Please sign in to comment.