
Implement secure horizontal scheme for federated learning #10231

Merged

84 commits
8570ba5
Add additional data split mode to cover the secure vertical pipeline
ZiyueXu77 Jan 31, 2024
2d00db6
Add IsSecure info and update corresponding functions
ZiyueXu77 Jan 31, 2024
ab17f5a
Modify evaluate_splits to block non-label owners from performing hist comp…
ZiyueXu77 Jan 31, 2024
fb1787c
Continue using Allgather for best split sync for secure vertical, equ…
ZiyueXu77 Feb 2, 2024
7a2a2b8
Modify histogram sync scheme for secure vertical case, can identify g…
ZiyueXu77 Feb 6, 2024
3ca3142
Sync cut information across clients, full pipeline works for testing …
ZiyueXu77 Feb 7, 2024
22dd522
Code cleanup, phase 1 of alternative vertical pipeline finished
ZiyueXu77 Feb 8, 2024
52e8951
Code clean
ZiyueXu77 Feb 8, 2024
e9eef15
change kColS to kColSecure to avoid confusion with kCols
ZiyueXu77 Feb 12, 2024
91c8a2f
Replace allreduce with allgather, functional but inefficient version
ZiyueXu77 Feb 13, 2024
8340c26
Update AllGather behavior from individual pair to bulk by adopting ne…
ZiyueXu77 Feb 13, 2024
42a9df1
comment out the record printing
ZiyueXu77 Feb 13, 2024
41e5abb
fix pointer bug for histsync with allgather
ZiyueXu77 Feb 20, 2024
ea5dc98
Merge branch 'dmlc:master' into SecureBoostP2
ZiyueXu77 Feb 20, 2024
5d542f8
Merge branch 'dmlc:master' into SecureBoostP2
ZiyueXu77 Feb 23, 2024
d91be10
identify the HE adding locations
ZiyueXu77 Feb 23, 2024
dd60317
revise and simplify template code
ZiyueXu77 Mar 6, 2024
8da824c
revise and simplify template code
ZiyueXu77 Mar 6, 2024
fb9f4fa
prepare aggregator for gh broadcast
ZiyueXu77 Mar 13, 2024
e77f8c6
prepare histogram for histindex and row index for allgather
ZiyueXu77 Mar 14, 2024
7ef48c8
Merge branch 'vertical-federated-learning' into SecureBoostP2
ZiyueXu77 Mar 15, 2024
8405791
fix conflicts
ZiyueXu77 Mar 15, 2024
db7d518
fix conflicts
ZiyueXu77 Mar 15, 2024
dd6adde
fix format
ZiyueXu77 Mar 15, 2024
9567e67
fix allgather logic and update unit test
ZiyueXu77 Mar 19, 2024
53800f2
fix linting
ZiyueXu77 Mar 19, 2024
b7e70f1
fix linting and other unit test issues
ZiyueXu77 Mar 20, 2024
49e8fd6
fix linting and other unit test issues
ZiyueXu77 Mar 20, 2024
da0f7a6
integration with interface initial attempt
ZiyueXu77 Mar 22, 2024
406cda3
integration with interface initial attempt
ZiyueXu77 Mar 22, 2024
f6c63aa
integration with interface initial attempt
ZiyueXu77 Mar 22, 2024
f223df7
functional integration with interface
ZiyueXu77 Apr 1, 2024
d881d84
remove debugging prints
ZiyueXu77 Apr 1, 2024
2997cf7
remove processor from another PR
ZiyueXu77 Apr 1, 2024
3a1f9ac
Update the processor functions according to new processor implementation
ZiyueXu77 Apr 12, 2024
1107604
Move processor interface init from learner to communicator
ZiyueXu77 Apr 12, 2024
30b7ed5
Move processor interface init from learner to communicator functional
ZiyueXu77 Apr 12, 2024
a3ddf7d
switch to allgatherV for encrypted messages with varying lengths
ZiyueXu77 Apr 15, 2024
3123b51
consolidate with processor interface PR
ZiyueXu77 Apr 19, 2024
73225a0
remove prints and fix format
ZiyueXu77 Apr 23, 2024
e85b1fb
fix linting over reference pass
ZiyueXu77 Apr 24, 2024
57750b4
fix undefined symbol issue
ZiyueXu77 Apr 24, 2024
fa2665a
fix processor test
ZiyueXu77 Apr 24, 2024
87d2fdb
secure vertical relies on processor, move the unit test
ZiyueXu77 Apr 24, 2024
9941293
type correction
ZiyueXu77 Apr 24, 2024
dd4f440
type correction
ZiyueXu77 Apr 24, 2024
5b2dfe6
extra linting from last change
ZiyueXu77 Apr 24, 2024
80d3b89
Added Windows support
nvidianz Apr 24, 2024
184b67f
Merge pull request #4 from nvidianz/processor-windows-support
ZiyueXu77 Apr 25, 2024
3382707
fix for cstdint types
ZiyueXu77 Apr 25, 2024
2a8f19a
fix for cstdint types
ZiyueXu77 Apr 25, 2024
9ff2935
Added support for horizontal secure XGBoost
nvidianz Apr 25, 2024
38e9d3d
Merge pull request #5 from nvidianz/processor-horizontal-support
ZiyueXu77 Apr 25, 2024
38c176c
update with mock plugin
ZiyueXu77 Apr 26, 2024
a5ce92e
secure horizontal fully functional with mock plugin
ZiyueXu77 Apr 26, 2024
5e824ac
linting fix
ZiyueXu77 Apr 26, 2024
81db216
linting fix
ZiyueXu77 Apr 26, 2024
15c211a
linting fix
ZiyueXu77 Apr 26, 2024
f7341cd
fix type
ZiyueXu77 Apr 26, 2024
a579205
Merge branch 'vertical-federated-learning' into SecureHorizontal
ZiyueXu77 Apr 29, 2024
35d8c15
change loader and proc params input pattern to align with std map
ZiyueXu77 Apr 29, 2024
3d31905
update with secure vertical incorporation
ZiyueXu77 May 13, 2024
7f86787
Merge branch 'vertical-federated-learning' into SecureHorizontal
ZiyueXu77 May 16, 2024
bdcb6e2
Update mock_processor to enable nvflare usage
ZiyueXu77 May 28, 2024
a8205d3
[backport] Fix compiling with the latest CTK. (#10263)
trivialfis May 29, 2024
ae77f2d
Merge remote-tracking branch 'ZiyueXu77/SecureHorizontal' into Secure…
trivialfis May 29, 2024
cc13605
fix secure horizontal inference
ZiyueXu77 May 29, 2024
d7a6da6
initialized aggr context only once
ZiyueXu77 May 29, 2024
032b14d
Added support for multiple plugins in a single lib
nvidianz May 30, 2024
ea9b298
Merge pull request #7 from nvidianz/support-multi-processors
ZiyueXu77 May 30, 2024
7f3472e
remove redundant condition
ZiyueXu77 Jun 3, 2024
5569e78
Added support for boolean in proc_params
nvidianz Jun 4, 2024
71e578c
Merge pull request #10 from nvidianz/support-bool-params-2nd-try
ZiyueXu77 Jun 4, 2024
454f69d
free buffer
ZiyueXu77 Jun 11, 2024
e2f77e2
Merge branch 'vertical-federated-learning' into SecureHorizontal
trivialfis Jun 18, 2024
6e4a3fb
Merge branch 'vertical-federated-learning' into SecureHorizontal
trivialfis Jun 18, 2024
61c8f47
CUDA.
trivialfis Jun 18, 2024
05be1e8
Fix clean build.
trivialfis Jun 18, 2024
e0795cf
Fix include.
trivialfis Jun 18, 2024
fd4f331
tidy.
trivialfis Jun 18, 2024
074c63b
lint.
trivialfis Jun 18, 2024
2a8fd72
nolint.
trivialfis Jun 18, 2024
ac02279
disable.
trivialfis Jun 18, 2024
1ef69ea
disable sanitizer.
trivialfis Jun 18, 2024
10 changes: 5 additions & 5 deletions include/xgboost/data.h
@@ -40,7 +40,7 @@ enum class DataType : uint8_t {

enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };

-enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2 };
+enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2, kRowSecure = 3 };

/*!
* \brief Meta information about dataset, always sit in memory.
@@ -181,16 +181,16 @@ class MetaInfo {
void SynchronizeNumberOfColumns(Context const* ctx);

/*! \brief Whether the data is split row-wise. */
-  bool IsRowSplit() const {
-    return data_split_mode == DataSplitMode::kRow;
-  }
+  bool IsRowSplit() const { return (data_split_mode == DataSplitMode::kRow)
+                                   || (data_split_mode == DataSplitMode::kRowSecure); }

/** @brief Whether the data is split column-wise. */
bool IsColumnSplit() const { return (data_split_mode == DataSplitMode::kCol)
|| (data_split_mode == DataSplitMode::kColSecure); }

/** @brief Whether the data is split column-wise with secure computation. */
-  bool IsSecure() const { return data_split_mode == DataSplitMode::kColSecure; }
+  bool IsSecure() const { return (data_split_mode == DataSplitMode::kColSecure)
+                                 || (data_split_mode == DataSplitMode::kRowSecure); }

/** @brief Whether this is a learning to rank data. */
bool IsRanking() const { return !group_ptr_.empty(); }
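Taken together, the enum and predicate changes mean IsSecure() no longer implies a column split. A minimal sketch of how calling code can distinguish the four modes; Route() is a hypothetical helper with placeholder branch bodies, not part of the PR:

#include "xgboost/data.h"

// Sketch only. Note that IsSecure() is now true for both kColSecure and
// kRowSecure, so secure-horizontal code should combine it with IsRowSplit().
void Route(xgboost::MetaInfo const& info) {
  if (info.IsSecure() && info.IsRowSplit()) {
    // secure horizontal: histograms are encrypted before aggregation
  } else if (info.IsSecure() && info.IsColumnSplit()) {
    // secure vertical: only the label owner evaluates splits
  } else if (info.IsColumnSplit()) {
    // plain vertical split
  } else {
    // plain horizontal (row-wise) split
  }
}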
45 changes: 42 additions & 3 deletions src/collective/aggregator.h
@@ -14,6 +14,7 @@
#include "communicator-inl.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/data.h" // for MetaINfo
#include "../processing/processor.h" // for Processor

namespace xgboost::collective {

@@ -69,7 +70,7 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::si
* @param result The HostDeviceVector storing the results.
* @param function The function used to calculate the results.
*/
-template <typename T, typename Function>
+template <bool is_gpair, typename T, typename Function>
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
Function&& function) {
if (info.IsVerticalFederated()) {
@@ -96,8 +97,46 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
}
collective::Broadcast(&size, sizeof(std::size_t), 0);

-    result->Resize(size);
-    collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
+    if (info.IsSecure() && is_gpair) {
Review comment (Member): May I ask why the vertical federated learning section is being modified for horizontal learning?

Reply (Author): This PR adds the horizontal functions on top of the vertical P2 PR (#10124). Since that PR has not been merged yet, this diff also shows all of the vertical modifications; only the parts beyond #10124 are new here.
+      // Under secure mode, gpairs will be processed to vector and encrypt
+      // information only available on rank 0
+      std::size_t buffer_size{};
+      std::int8_t *buffer;
+      if (collective::GetRank() == 0) {
+        std::vector<double> vector_gh;
+        for (std::size_t i = 0; i < size; i++) {
+          auto gpair = result->HostVector()[i];
+          // cast from GradientPair to float pointer
+          auto gpair_ptr = reinterpret_cast<float*>(&gpair);
+          // save to vector
+          vector_gh.push_back(gpair_ptr[0]);
+          vector_gh.push_back(gpair_ptr[1]);
+        }
+        // provide the vectors to the processor interface
+        size_t size;
+        auto buf = processor_instance->ProcessGHPairs(&size, vector_gh);
+        buffer_size = size;
+        buffer = reinterpret_cast<std::int8_t *>(buf);
+      }
+
+      // broadcast the buffer size for other ranks to prepare
+      collective::Broadcast(&buffer_size, sizeof(std::size_t), 0);
+      // prepare buffer on passive parties for satisfying broadcast mpi call
+      if (collective::GetRank() != 0) {
+        buffer = reinterpret_cast<std::int8_t *>(malloc(buffer_size));
+      }
+
+      // broadcast the data buffer holding processed gpairs
+      collective::Broadcast(buffer, buffer_size, 0);
+
+      // call HandleGHPairs
+      size_t size;
+      processor_instance->HandleGHPairs(&size, buffer, buffer_size);
+    } else {
+      // clear text mode, broadcast the data directly
+      result->Resize(size);
+      collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
+    }
} else {
std::forward<Function>(function)();
}
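The secure branch above flattens each GradientPair into two doubles, hands them to the processor plugin on rank 0, and broadcasts the resulting buffer to all parties. As a rough sketch, a pass-through mock of the two plugin calls might look like the following; the signatures are inferred from the call sites (the real interface is declared in src/processing/processor.h), and the "encryption" here is a no-op:

#include <cstdint>
#include <cstring>
#include <vector>

// Pass-through mock of the plugin calls used above.
class PassThroughProcessor {
 public:
  // Rank 0: serialize the flattened (g, h) vector; a real plugin encrypts here.
  void* ProcessGHPairs(std::size_t* out_size, std::vector<double> const& gh) {
    *out_size = gh.size() * sizeof(double);
    buffer_.resize(*out_size);
    std::memcpy(buffer_.data(), gh.data(), *out_size);
    return buffer_.data();
  }
  // All ranks: accept the broadcast buffer and keep it for later histogram
  // aggregation; a real plugin would stash and later decrypt the ciphertext.
  void HandleGHPairs(std::size_t* out_size, void* buffer, std::size_t buf_size) {
    auto* bytes = static_cast<std::int8_t*>(buffer);
    buffer_.assign(bytes, bytes + buf_size);
    *out_size = buf_size;
  }

 private:
  std::vector<std::int8_t> buffer_;
};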
70 changes: 58 additions & 12 deletions src/collective/communicator.cc
@@ -1,6 +1,7 @@
/*!
* Copyright 2022 XGBoost contributors
*/
+#include <map>
#include "communicator.h"

#include "comm.h"
@@ -9,14 +10,39 @@
#include "rabit_communicator.h"

#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_communicator.h"
#include "../../plugin/federated/federated_communicator.h"
#endif

#include "../processing/processor.h"
processing::Processor *processor_instance;

namespace xgboost::collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};
thread_local std::string Communicator::nccl_path_{};

+std::map<std::string, std::string> json_to_map(xgboost::Json const& config, std::string key) {
+  auto json_map = xgboost::OptionalArg<xgboost::Object>(config, key, xgboost::JsonObject::Map{});
+  std::map<std::string, std::string> params{};
+  for (auto entry : json_map) {
+    std::string text;
+    xgboost::Value* value = &(entry.second.GetValue());
+    if (value->Type() == xgboost::Value::ValueKind::kString) {
Review comment (Member): IsA<String> (suggesting the IsA<String>() helper rather than comparing ValueKind directly).
+      text = reinterpret_cast<xgboost::String *>(value)->GetString();
+    } else if (value->Type() == xgboost::Value::ValueKind::kInteger) {
+      auto num = reinterpret_cast<xgboost::Integer *>(value)->GetInteger();
+      text = std::to_string(num);
+    } else if (value->Type() == xgboost::Value::ValueKind::kNumber) {
+      auto num = reinterpret_cast<xgboost::Number *>(value)->GetNumber();
+      text = std::to_string(num);
+    } else {
+      text = "Unsupported type ";
+    }
+    params[entry.first] = text;
+  }
+  return params;
+}

void Communicator::Init(Json const& config) {
auto nccl = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});
nccl_path_ = nccl;
@@ -38,26 +64,46 @@ void Communicator::Init(Json const& config) {
}
case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
      communicator_.reset(FederatedCommunicator::Create(config));
+      // Get processor configs
+      std::string plugin_name{};
+      std::string loader_params_key{};
+      std::string loader_params_map{};
+      std::string proc_params_key{};
+      std::string proc_params_map{};
+      plugin_name = OptionalArg<String>(config, "plugin_name", plugin_name);
+      // Initialize processor if plugin_name is provided
+      if (!plugin_name.empty()) {
+        std::map<std::string, std::string> loader_params = json_to_map(config, "loader_params");
+        std::map<std::string, std::string> proc_params = json_to_map(config, "proc_params");
+        processing::ProcessorLoader loader(loader_params);
+        processor_instance = loader.load(plugin_name);
+        processor_instance->Initialize(collective::GetRank() == 0, proc_params);
+      }
Review comment (Member): Do you have a document for the expected parameters?
#else
      LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
break;
}
-    case CommunicatorType::kInMemory:
-    case CommunicatorType::kInMemoryNccl: {
-      communicator_.reset(InMemoryCommunicator::Create(config));
-      break;
-    }
-    case CommunicatorType::kUnknown:
-      LOG(FATAL) << "Unknown communicator type.";
-      break;
-  }
+    case CommunicatorType::kInMemory:
+    case CommunicatorType::kInMemoryNccl: {
+      communicator_.reset(InMemoryCommunicator::Create(config));
+      break;
+    }
+    case CommunicatorType::kUnknown:
+      LOG(FATAL) << "Unknown communicator type.";
+  }
}

#ifndef XGBOOST_USE_CUDA
void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(new NoOpCommunicator());
+  if (processor_instance != nullptr) {
+    processor_instance->Shutdown();
+    processor_instance = nullptr;
+  }
}
#endif
} // namespace xgboost::collective
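To answer the reviewer's question about expected parameters in code form, here is a hypothetical communicator config wired through the parsing above. The keys plugin_name, loader_params, and proc_params come from this diff; the nested entries, server settings, and include paths are illustrative assumptions only:

#include "xgboost/json.h"
#include "xgboost/string_view.h"

// Sketch: a federated config enabling an encryption plugin. json_to_map()
// above accepts string, integer, and number values (and, per a later commit
// in this PR, booleans in proc_params).
void InitSecureFederated() {
  auto config = xgboost::Json::Load(xgboost::StringView{R"({
    "xgboost_communicator": "federated",
    "federated_server_address": "localhost:9091",
    "federated_world_size": 2,
    "federated_rank": 0,
    "plugin_name": "mock",
    "loader_params": {"LIBRARY_PATH": "/tmp"},
    "proc_params": {"debug": true}
  })"});
  xgboost::collective::Communicator::Init(config);
}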
2 changes: 1 addition & 1 deletion src/common/quantile.cu
@@ -179,7 +179,7 @@ common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
merge_path.data(), [=] __device__(Tuple const &t) -> Tuple {
auto ind = get_ind(t); // == 0 if element is from x
// x_counter, y_counter
-        return thrust::make_tuple<uint64_t, uint64_t>(!ind, ind);
+        return thrust::make_tuple(static_cast<uint64_t>(!ind), static_cast<uint64_t>(ind));
});

// Compute the index for both x and y (which of the element in a and b are used in each
8 changes: 4 additions & 4 deletions src/data/ellpack_page.cu
@@ -171,11 +171,11 @@ struct WriteCompressedEllpackFunctor {

using Tuple = thrust::tuple<size_t, size_t, size_t>;
__device__ size_t operator()(Tuple out) {
-    auto e = batch.GetElement(out.get<2>());
+    auto e = batch.GetElement(thrust::get<2>(out));
if (is_valid(e)) {
// -1 because the scan is inclusive
size_t output_position =
-          accessor.row_stride * e.row_idx + out.get<1>() - 1;
+          accessor.row_stride * e.row_idx + thrust::get<1>(out) - 1;
uint32_t bin_idx = 0;
if (common::IsCat(feature_types, e.column_idx)) {
bin_idx = accessor.SearchBin<true>(e.value, e.column_idx);
@@ -192,8 +192,8 @@ template <typename Tuple>
struct TupleScanOp {
__device__ Tuple operator()(Tuple a, Tuple b) {
// Key equal
-    if (a.template get<0>() == b.template get<0>()) {
-      b.template get<1>() += a.template get<1>();
+    if (thrust::get<0>(a) == thrust::get<0>(b)) {
+      thrust::get<1>(b) += thrust::get<1>(a);
return b;
}
// Not equal
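Both CUDA hunks above are the same portability fix: with recent Thrust releases (where thrust::tuple aliases cuda::std::tuple), the member .get<I>() and explicit template arguments on make_tuple no longer compile, while the free-function forms work on old and new versions alike. A minimal host-side illustration, standalone and not taken from the PR:

#include <cstdint>
#include <thrust/tuple.h>

int main() {
  // Portable: free-function make_tuple with explicit casts instead of
  // explicit template arguments, and thrust::get<I>(t) instead of t.get<I>().
  auto t = thrust::make_tuple(static_cast<uint64_t>(1), static_cast<uint64_t>(2));
  return static_cast<int>(thrust::get<0>(t) + thrust::get<1>(t));  // returns 3
}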
7 changes: 4 additions & 3 deletions src/learner.cc
@@ -846,7 +846,7 @@ class LearnerConfiguration : public Learner {

void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
base_score->Reshape(1);
-    collective::ApplyWithLabels(this->Ctx(), info, base_score->Data(),
+    collective::ApplyWithLabels<false>(this->Ctx(), info, base_score->Data(),
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
}
};
@@ -1472,8 +1472,9 @@ class LearnerImpl : public LearnerIO {
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
-    collective::ApplyWithLabels(&ctx_, info, out_gpair->Data(),
-                                [&] { obj_->GetGradient(preds, info, iter, out_gpair); });
+    // calculate gradient and communicate
+    collective::ApplyWithLabels<true>(&ctx_, info, out_gpair->Data(),
+                                      [&] { obj_->GetGradient(preds, info, iter, out_gpair); });
}

/*! \brief random number transformation seed. */