Implement secure horizontal scheme for federated learning (#10231)
* Add additional data split mode to cover the secure vertical pipeline

* Add IsSecure info and update corresponding functions

* Modify evaluate_splits to block non-label owners from performing hist computation under the secure scenario

* Continue using Allgather for best-split sync for secure vertical, equivalent to broadcast

* Modify histogram sync scheme for the secure vertical case: the global best split can be identified, but the split still needs to be applied correctly

* Sync cut information across clients; the full pipeline works for the test case

* Code cleanup, phase 1 of alternative vertical pipeline finished

* Code clean

* change kColS to kColSecure to avoid confusion with kCols

* Replace allreduce with allgather, functional but inefficient version

* Update AllGather behavior from individual pairs to bulk transmission by adopting a new histogram transmission data structure of a flat vector (see the flattening sketch after the changed-files summary below)

* comment out the record printing

* fix pointer bug for histsync with allgather

* identify the locations where HE addition is applied

* revise and simplify template code

* revise and simplify template code

* prepare aggregator for gh broadcast

* prepare histogram for histindex and row index for allgather

* fix conflicts

* fix conflicts

* fix format

* fix allgather logic and update unit test

* fix linting

* fix linting and other unit test issues

* fix linting and other unit test issues

* integration with interface initial attempt

* integration with interface initial attempt

* integration with interface initial attempt

* functional integration with interface

* remove debugging prints

* remove processor from another PR

* Update the processor functions according to new processor implementation

* Move processor interface init from learner to communicator

* Move processor interface init from learner to communicator functional

* switch to AllgatherV for encrypted messages with varying lengths

* consolidate with processor interface PR

* remove prints and fix format

* fix linting over reference pass

* fix undefined symbol issue

* fix processor test

* secure vertical relies on processor, move the unit test

* type correction

* type correction

* extra linting from last change

* Added Windows support

* fix for cstdint types

* fix for cstdint types

* Added support for horizontal secure XGBoost

* update with mock plugin

* secure horizontal fully functional with mock plugin

* linting fix

* linting fix

* linting fix

* fix type

* change loader and proc params input pattern to align with std::map

* update with secure vertical incorporation

* Update mock_processor to enable NVFlare usage

* [backport] Fix compiling with the latest CTX. (#10263)

* fix secure horizontal inference

* initialized aggr context only once

* Added support for multiple plugins in a single lib

* remove redundant condition

* Added support for boolean in proc_params

* free buffer

* CUDA.

* Fix clean build.

* Fix include.

* tidy.

* lint.

* nolint.

* disable.

* disable sanitizer.

---------

Co-authored-by: Zhihong Zhang <[email protected]>
Co-authored-by: Jiaming Yuan <[email protected]>
3 people authored Jun 18, 2024
1 parent 09bc2c7 commit de4013f
Showing 15 changed files with 832 additions and 79 deletions.
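As a companion to the flat-vector commit message above, here is a minimal sketch of the bulk-transmission idea: per-node histograms are packed into one flat buffer so that a single collective call replaces per-pair exchanges. The helper names (FlattenHistograms, UnflattenHistograms) are illustrative, not part of the committed API; the actual transmission in the commit goes through the collective Allgather/AllgatherV calls.

#include <cstddef>
#include <vector>

// Pack all per-node histograms (sequences of (grad, hess) doubles) into one
// flat vector so a single Allgather-style call can move them in bulk.
std::vector<double> FlattenHistograms(std::vector<std::vector<double>> const& per_node) {
  std::vector<double> flat;
  for (auto const& hist : per_node) {
    flat.insert(flat.end(), hist.cbegin(), hist.cend());
  }
  return flat;
}

// Slice a gathered flat buffer back into per-node histograms, assuming every
// node's histogram has the same fixed length hist_size (> 0).
std::vector<std::vector<double>> UnflattenHistograms(std::vector<double> const& flat,
                                                     std::size_t hist_size) {
  std::vector<std::vector<double>> per_node;
  for (std::size_t i = 0; i + hist_size <= flat.size(); i += hist_size) {
    per_node.emplace_back(flat.cbegin() + i, flat.cbegin() + i + hist_size);
  }
  return per_node;
}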
1 change: 1 addition & 0 deletions cmake/Utils.cmake
@@ -227,6 +227,7 @@ macro(xgboost_target_link_libraries target)
  else()
    target_link_libraries(${target} PRIVATE Threads::Threads ${CMAKE_THREAD_LIBS_INIT})
  endif()
+  target_link_libraries(${target} PRIVATE ${CMAKE_DL_LIBS})

  if(USE_OPENMP)
    if(BUILD_STATIC_LIB)
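The new CMAKE_DL_LIBS link is needed because the processor plugin is loaded at runtime as a shared library. A minimal sketch of that mechanism on POSIX follows; the library name and entry-point symbol are hypothetical, and the commit's loader also gained Windows support via the equivalent LoadLibrary/GetProcAddress path.

#include <dlfcn.h>  // dlopen/dlsym/dlclose, which is why CMAKE_DL_LIBS is linked

#include <iostream>

int main() {
  // Hypothetical plugin library; the real loader resolves the library for the
  // configured plugin_name from the loader parameters.
  void* handle = dlopen("libmock_processor.so", RTLD_LAZY);
  if (handle == nullptr) {
    std::cerr << dlerror() << "\n";
    return 1;
  }
  // Hypothetical entry-point symbol exported by the plugin.
  void* sym = dlsym(handle, "LoadProcessor");
  std::cout << "entry point " << (sym != nullptr ? "found" : "missing") << "\n";
  dlclose(handle);
  return 0;
}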
10 changes: 5 additions & 5 deletions include/xgboost/data.h
@@ -40,7 +40,7 @@ enum class DataType : uint8_t {

enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };

-enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2 };
+enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2, kRowSecure = 3 };

/*!
 * \brief Meta information about dataset, always sit in memory.
@@ -181,16 +181,16 @@ class MetaInfo {
  void SynchronizeNumberOfColumns(Context const* ctx);

  /*! \brief Whether the data is split row-wise. */
-  bool IsRowSplit() const {
-    return data_split_mode == DataSplitMode::kRow;
-  }
+  bool IsRowSplit() const { return (data_split_mode == DataSplitMode::kRow)
+                                || (data_split_mode == DataSplitMode::kRowSecure); }

  /** @brief Whether the data is split column-wise. */
  bool IsColumnSplit() const { return (data_split_mode == DataSplitMode::kCol)
                            || (data_split_mode == DataSplitMode::kColSecure); }

  /** @brief Whether the data is split column-wise with secure computation. */
-  bool IsSecure() const { return data_split_mode == DataSplitMode::kColSecure; }
+  bool IsSecure() const { return (data_split_mode == DataSplitMode::kColSecure)
+                              || (data_split_mode == DataSplitMode::kRowSecure); }

  /** @brief Whether this is a learning to rank data. */
  bool IsRanking() const { return !group_ptr_.empty(); }
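A small usage sketch (not part of the commit, compiled against libxgboost) showing how the new kRowSecure mode interacts with the predicates above: secure horizontal data reports both row-split and secure, mirroring how kColSecure reports both column-split and secure.

#include <cassert>

#include "xgboost/data.h"

int main() {
  xgboost::MetaInfo info;

  // Secure horizontal: row-wise split combined with secure computation.
  info.data_split_mode = xgboost::DataSplitMode::kRowSecure;
  assert(info.IsRowSplit() && info.IsSecure() && !info.IsColumnSplit());

  // Secure vertical: column-wise split combined with secure computation.
  info.data_split_mode = xgboost::DataSplitMode::kColSecure;
  assert(info.IsColumnSplit() && info.IsSecure() && !info.IsRowSplit());
  return 0;
}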
48 changes: 45 additions & 3 deletions src/collective/aggregator.h
@@ -14,6 +14,7 @@
#include "communicator-inl.h"
#include "xgboost/collective/result.h"  // for Result
#include "xgboost/data.h"               // for MetaInfo
+#include "../processing/processor.h"   // for Processor

namespace xgboost::collective {

@@ -69,7 +70,7 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::si
 * @param result The HostDeviceVector storing the results.
 * @param function The function used to calculate the results.
 */
-template <typename T, typename Function>
+template <bool is_gpair, typename T, typename Function>
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
                     Function&& function) {
  if (info.IsVerticalFederated()) {
@@ -96,8 +97,49 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
    }
    collective::Broadcast(&size, sizeof(std::size_t), 0);

-    result->Resize(size);
-    collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
+    if (info.IsSecure() && is_gpair) {
+      // Under secure mode, gpairs are serialized into a vector and encrypted;
+      // the plaintext is available only on rank 0.
+      std::size_t buffer_size{};
+      std::int8_t *buffer;
+      if (collective::GetRank() == 0) {
+        std::vector<double> vector_gh;
+        for (std::size_t i = 0; i < size; i++) {
+          auto gpair = result->HostVector()[i];
+          // cast from GradientPair to a float pointer
+          auto gpair_ptr = reinterpret_cast<float*>(&gpair);
+          // save to the vector
+          vector_gh.push_back(gpair_ptr[0]);
+          vector_gh.push_back(gpair_ptr[1]);
+        }
+        // provide the vector to the processor interface
+        size_t size;
+        auto buf = processor_instance->ProcessGHPairs(&size, vector_gh);
+        buffer_size = size;
+        buffer = reinterpret_cast<std::int8_t *>(buf);
+      }
+
+      // broadcast the buffer size so other ranks can prepare
+      collective::Broadcast(&buffer_size, sizeof(std::size_t), 0);
+      // prepare the buffer on passive parties to satisfy the broadcast call
+      if (collective::GetRank() != 0) {
+        buffer = reinterpret_cast<std::int8_t *>(malloc(buffer_size));
+      }
+
+      // broadcast the data buffer holding the processed gpairs
+      collective::Broadcast(buffer, buffer_size, 0);

+      // call HandleGHPairs
+      size_t size;
+      processor_instance->HandleGHPairs(&size, buffer, buffer_size);
+
+      // free the buffer
+      free(buffer);
+    } else {
+      // clear-text mode: broadcast the data directly
+      result->Resize(size);
+      collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
+    }
  } else {
    std::forward<Function>(function)();
  }
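The commit list mentions a mock plugin for exercising the secure path. Below is a heavily hedged sketch of a pass-through counterpart to the ProcessGHPairs/HandleGHPairs calls above; the signatures are inferred from the call sites, and the class body is illustrative rather than the committed mock_processor implementation.

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <vector>

// Illustrative pass-through "encryption": a real plugin would encrypt the
// pairs, while this sketch only copies them into a heap buffer, matching the
// call pattern used by ApplyWithLabels above.
class MockProcessor {
 public:
  // Rank 0 hands in the flattened (grad, hess) doubles and receives an opaque
  // buffer to broadcast; the caller frees it, as in the aggregator code.
  void* ProcessGHPairs(std::size_t* size, std::vector<double> const& pairs) {
    *size = pairs.size() * sizeof(double);
    auto* buffer = static_cast<std::int8_t*>(std::malloc(*size));
    std::memcpy(buffer, pairs.data(), *size);
    return buffer;
  }

  // Every rank receives the broadcast buffer; a real plugin would decrypt or
  // stash it for later histogram work, the mock only reports the size.
  void HandleGHPairs(std::size_t* size, void* buffer, std::size_t buf_size) {
    (void)buffer;
    *size = buf_size;
  }
};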
72 changes: 60 additions & 12 deletions src/collective/communicator.cc
@@ -1,6 +1,7 @@
/*!
 * Copyright 2022 XGBoost contributors
 */
+#include <map>
#include "communicator.h"

#include "comm.h"
@@ -9,14 +10,41 @@
#include "rabit_communicator.h"

#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_communicator.h"
#endif

+#include "../processing/processor.h"
+processing::Processor *processor_instance;
+
namespace xgboost::collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};
thread_local std::string Communicator::nccl_path_{};

+std::map<std::string, std::string> JsonToMap(xgboost::Json const& config, std::string key) {
+  auto json_map = xgboost::OptionalArg<xgboost::Object>(config, key, xgboost::JsonObject::Map{});
+  std::map<std::string, std::string> params{};
+  for (auto entry : json_map) {
+    std::string text;
+    xgboost::Value* value = &(entry.second.GetValue());
+    if (value->Type() == xgboost::Value::ValueKind::kString) {
+      text = reinterpret_cast<xgboost::String *>(value)->GetString();
+    } else if (value->Type() == xgboost::Value::ValueKind::kInteger) {
+      auto num = reinterpret_cast<xgboost::Integer *>(value)->GetInteger();
+      text = std::to_string(num);
+    } else if (value->Type() == xgboost::Value::ValueKind::kNumber) {
+      auto num = reinterpret_cast<xgboost::Number *>(value)->GetNumber();
+      text = std::to_string(num);
+    } else if (value->Type() == xgboost::Value::ValueKind::kBoolean) {
+      text = reinterpret_cast<xgboost::Boolean *>(value)->GetBoolean() ? "true" : "false";
+    } else {
+      text = "Unsupported type";
+    }
+    params[entry.first] = text;
+  }
+  return params;
+}
+
void Communicator::Init(Json const& config) {
  auto nccl = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});
  nccl_path_ = nccl;
@@ -38,26 +66,46 @@ void Communicator::Init(Json const& config) {
    }
    case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
      communicator_.reset(FederatedCommunicator::Create(config));
+      // Get processor configs
+      std::string plugin_name{};
+      std::string loader_params_key{};
+      std::string loader_params_map{};
+      std::string proc_params_key{};
+      std::string proc_params_map{};
+      plugin_name = OptionalArg<String>(config, "plugin_name", plugin_name);
+      // Initialize processor if plugin_name is provided
+      if (!plugin_name.empty()) {
+        std::map<std::string, std::string> loader_params = JsonToMap(config, "loader_params");
+        std::map<std::string, std::string> proc_params = JsonToMap(config, "proc_params");
+        processing::ProcessorLoader loader(loader_params);
+        processor_instance = loader.Load(plugin_name);
+        processor_instance->Initialize(collective::GetRank() == 0, proc_params);
+      }
#else
      LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
      break;
    }
-    case CommunicatorType::kInMemory:
-    case CommunicatorType::kInMemoryNccl: {
-      communicator_.reset(InMemoryCommunicator::Create(config));
-      break;
-    }
-    case CommunicatorType::kUnknown:
-      LOG(FATAL) << "Unknown communicator type.";
-      break;
-  }
+    case CommunicatorType::kInMemory:
+    case CommunicatorType::kInMemoryNccl: {
+      communicator_.reset(InMemoryCommunicator::Create(config));
+      break;
+    }
+    case CommunicatorType::kUnknown:
+      LOG(FATAL) << "Unknown communicator type.";
+  }
}

#ifndef XGBOOST_USE_CUDA
void Communicator::Finalize() {
  communicator_->Shutdown();
  communicator_.reset(new NoOpCommunicator());
+  if (processor_instance != nullptr) {
+    processor_instance->Shutdown();
+    processor_instance = nullptr;
+  }
}
#endif
}  // namespace xgboost::collective
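For reference, a hedged example of a federated configuration that would exercise JsonToMap and the plugin initialization above. The keys plugin_name, loader_params, and proc_params come from the diff; the communicator key, the parameter names, and the values are illustrative assumptions, not committed defaults.

#include "xgboost/json.h"

xgboost::Json MakeFederatedConfig() {
  using namespace xgboost;  // Json, Object, String, Integer, Boolean
  Json config{Object{}};
  config["xgboost_communicator"] = String{"federated"};  // assumed communicator key
  config["plugin_name"] = String{"mock"};  // a non-empty name triggers ProcessorLoader
  // Loader parameters become a std::map<std::string, std::string> via JsonToMap.
  config["loader_params"] = Object{};
  config["loader_params"]["LIBRARY_PATH"] = String{"/opt/xgboost/plugins"};  // illustrative key
  // Processor parameters: integers and booleans are stringified by JsonToMap.
  config["proc_params"] = Object{};
  config["proc_params"]["PRECISION"] = Integer{64};  // becomes "64"
  config["proc_params"]["DEBUG"] = Boolean{true};    // becomes "true"
  return config;
}

// Passing this config to collective initialization would load the plugin and
// call processor_instance->Initialize(rank == 0, proc_params), as in the
// kFederated case above.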
7 changes: 4 additions & 3 deletions src/learner.cc
@@ -846,7 +846,7 @@ class LearnerConfiguration : public Learner {

  void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
    base_score->Reshape(1);
-    collective::ApplyWithLabels(this->Ctx(), info, base_score->Data(),
+    collective::ApplyWithLabels<false>(this->Ctx(), info, base_score->Data(),
                                [&] { UsePtr(obj_)->InitEstimation(info, base_score); });
  }
};
@@ -1472,8 +1472,9 @@ class LearnerImpl : public LearnerIO {
  void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
                   std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
    out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
-    collective::ApplyWithLabels(&ctx_, info, out_gpair->Data(),
-                                [&] { obj_->GetGradient(preds, info, iter, out_gpair); });
+    // calculate gradient and communicate
+    collective::ApplyWithLabels<true>(&ctx_, info, out_gpair->Data(),
+                                      [&] { obj_->GetGradient(preds, info, iter, out_gpair); });
  }

  /*! \brief random number transformation seed. */
2 changes: 1 addition & 1 deletion src/objective/adaptive.cu
@@ -160,7 +160,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
  auto t_predt = d_predt.Slice(linalg::All(), group_idx);

  HostDeviceVector<float> quantiles;
-  collective::ApplyWithLabels(ctx, info, &quantiles, [&] {
+  collective::ApplyWithLabels<false>(ctx, info, &quantiles, [&] {
    auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
    auto d_row_index = dh::ToSpan(ridx);
    auto seg_beg = nptr.DevicePointer();