Implement secure horizontal scheme for federated learning #10231
Changes from 63 commits
@@ -1,6 +1,7 @@
/*!
 * Copyright 2022 XGBoost contributors
 */
#include <map>
#include "communicator.h"

#include "comm.h"
@@ -9,14 +10,39 @@
#include "rabit_communicator.h"

#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_communicator.h"
#endif

#include "../processing/processor.h"
processing::Processor *processor_instance;

namespace xgboost::collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};
thread_local std::string Communicator::nccl_path_{};

std::map<std::string, std::string> json_to_map(xgboost::Json const& config, std::string key) {
  auto json_map = xgboost::OptionalArg<xgboost::Object>(config, key, xgboost::JsonObject::Map{});
  std::map<std::string, std::string> params{};
  for (auto entry : json_map) {
    std::string text;
    xgboost::Value* value = &(entry.second.GetValue());
    if (value->Type() == xgboost::Value::ValueKind::kString) {
      text = reinterpret_cast<xgboost::String *>(value)->GetString();
    } else if (value->Type() == xgboost::Value::ValueKind::kInteger) {
      auto num = reinterpret_cast<xgboost::Integer *>(value)->GetInteger();
      text = std::to_string(num);
    } else if (value->Type() == xgboost::Value::ValueKind::kNumber) {
      auto num = reinterpret_cast<xgboost::Number *>(value)->GetNumber();
      text = std::to_string(num);
    } else {
      text = "Unsupported type ";
    }
    params[entry.first] = text;
  }
  return params;
}

void Communicator::Init(Json const& config) {
  auto nccl = OptionalArg<String>(config, "dmlc_nccl_path", std::string{DefaultNcclName()});
  nccl_path_ = nccl;
@@ -38,26 +64,46 @@ void Communicator::Init(Json const& config) {
    }
    case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
      communicator_.reset(FederatedCommunicator::Create(config));
      // Get processor configs
      std::string plugin_name{};
      std::string loader_params_key{};
      std::string loader_params_map{};
      std::string proc_params_key{};
      std::string proc_params_map{};
      plugin_name = OptionalArg<String>(config, "plugin_name", plugin_name);
      // Initialize processor if plugin_name is provided
      if (!plugin_name.empty()) {
        std::map<std::string, std::string> loader_params = json_to_map(config, "loader_params");
        std::map<std::string, std::string> proc_params = json_to_map(config, "proc_params");
        processing::ProcessorLoader loader(loader_params);
        processor_instance = loader.load(plugin_name);
        processor_instance->Initialize(collective::GetRank() == 0, proc_params);
      }
Reviewer comment: Do you have a document for the expected parameters?
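As an illustration only, this is the kind of configuration the federated branch above expects. The key names ("plugin_name", "loader_params", "proc_params") come from the code; the plugin name and parameter values are hypothetical, not a documented contract:

// Sketch only: the JSON that Communicator::Init would be parsed from.
// "mock_processor", "LIBRARY_PATH", and "SCHEME" are placeholder values.
const char* kExampleConfig = R"json({
  "plugin_name": "mock_processor",
  "loader_params": { "LIBRARY_PATH": "/tmp" },
  "proc_params":   { "SCHEME": "horizontal" }
})json";

With such a config, json_to_map(config, "loader_params") would flatten the nested object into {{"LIBRARY_PATH", "/tmp"}}, converting any integer or floating-point values to their string form.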
#else
      LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
      break;
    }
    case CommunicatorType::kInMemory:
    case CommunicatorType::kInMemoryNccl: {
      communicator_.reset(InMemoryCommunicator::Create(config));
      break;
    }
    case CommunicatorType::kUnknown:
      LOG(FATAL) << "Unknown communicator type.";
      break;
  }
}

#ifndef XGBOOST_USE_CUDA
void Communicator::Finalize() {
  communicator_->Shutdown();
  communicator_.reset(new NoOpCommunicator());
  if (processor_instance != nullptr) {
    processor_instance->Shutdown();
    processor_instance = nullptr;
  }
}
#endif
}  // namespace xgboost::collective
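For context, here is a minimal sketch of the processor lifecycle that the code above drives, using only the calls that appear in this diff (ProcessorLoader construction, load, Initialize, Shutdown). The plugin name and parameter values are placeholders, not documented defaults:

#include <map>
#include <string>
#include "../processing/processor.h"  // interface introduced by this PR

void ProcessorLifecycleSketch() {
  // Flattened parameter maps, normally produced by json_to_map above.
  std::map<std::string, std::string> loader_params = {{"LIBRARY_PATH", "/tmp"}};  // hypothetical value
  std::map<std::string, std::string> proc_params = {{"SCHEME", "horizontal"}};    // hypothetical value

  // Load the plugin and initialize it; in Communicator::Init the "active"
  // flag is collective::GetRank() == 0.
  processing::ProcessorLoader loader(loader_params);
  processing::Processor *proc = loader.load("mock_processor");  // hypothetical plugin name
  proc->Initialize(/*active=*/true, proc_params);

  // ... ProcessGHPairs / HandleGHPairs / ProcessAggregation / HandleAggregation ...

  // Mirrors the cleanup in Communicator::Finalize() above.
  proc->Shutdown();
}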
@@ -0,0 +1,175 @@
/**
 * Copyright 2014-2024 by XGBoost Contributors
 */
#include <iostream>
#include <cstring>
#include <cstdint>
#include "./mock_processor.h"

const char kSignature[] = "NVDADAM1";  // DAM (Direct Accessible Marshalling) V1
const int64_t kPrefixLen = 24;

bool ValidDam(void *buffer, std::size_t size) {
  return size >= kPrefixLen && memcmp(buffer, kSignature, strlen(kSignature)) == 0;
}
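The 24-byte DAM prefix used throughout this file is an 8-byte signature, an 8-byte total buffer size, and an 8-byte data-type tag, followed by the payload. Here is a self-contained sketch of writing that prefix, independent of the mock processor itself:

#include <cstdint>
#include <cstdlib>
#include <cstring>

// Sketch only: frame a payload of `payload_bytes` bytes with the DAM prefix.
// The caller owns the returned buffer and must free() it.
void *FrameDamBuffer(std::size_t payload_bytes, int64_t data_type, std::size_t *out_size) {
  const char signature[] = "NVDADAM1";
  const int64_t prefix_len = 24;
  *out_size = prefix_len + payload_bytes;
  char *buf = static_cast<char *>(calloc(*out_size, 1));
  int64_t total = static_cast<int64_t>(*out_size);
  memcpy(buf, signature, strlen(signature));  // bytes  0-7 : signature
  memcpy(buf + 8, &total, 8);                 // bytes  8-15: total buffer size
  memcpy(buf + 16, &data_type, 8);            // bytes 16-23: data-type tag
  return buf;                                 // payload starts at byte 24
}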

void* MockProcessor::ProcessGHPairs(std::size_t *size, const std::vector<double>& pairs) {
  *size = kPrefixLen + pairs.size()*10*8;  // Assume encrypted size is 10x

  int64_t buf_size = *size;
  // This memory needs to be freed
  char *buf = static_cast<char *>(calloc(*size, 1));
  memcpy(buf, kSignature, strlen(kSignature));
  memcpy(buf + 8, &buf_size, 8);
  memcpy(buf + 16, &kDataTypeGHPairs, 8);

  // Simulate encryption by duplicating value 10 times
  int index = kPrefixLen;
  for (auto value : pairs) {
    for (std::size_t i = 0; i < 10; i++) {
      memcpy(buf+index, &value, 8);
      index += 8;
    }
  }

  // Save pairs for future operations
  this->gh_pairs_ = new std::vector<double>(pairs);

Reviewer comment: Is this vector freed?

  return buf;
}

void* MockProcessor::HandleGHPairs(std::size_t *size, void *buffer, std::size_t buf_size) {
  *size = buf_size;
  if (!ValidDam(buffer, *size)) {
    return buffer;
  }

  // For mock, this call is used to set gh_pairs for passive sites
  if (!active_) {
    int8_t *ptr = static_cast<int8_t *>(buffer);
    ptr += kPrefixLen;
    double *pairs = reinterpret_cast<double *>(ptr);
    std::size_t num = (buf_size - kPrefixLen) / 8;
    gh_pairs_ = new std::vector<double>();
    for (std::size_t i = 0; i < num; i += 10) {
      gh_pairs_->push_back(pairs[i]);
    }
  }

  return buffer;
}
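Taken together, ProcessGHPairs and HandleGHPairs implement the mock "encryption" round trip: every double is written 10 times, and the receiver recovers the original sequence by reading every 10th double after the prefix. A self-contained sketch of that round trip, with no MockProcessor state involved:

#include <vector>

// Sketch only: mock-encrypt by 10x duplication, then decode by striding.
std::vector<double> MockRoundTrip(const std::vector<double>& pairs) {
  std::vector<double> encoded;
  for (double v : pairs) {
    for (int i = 0; i < 10; ++i) encoded.push_back(v);  // 10x duplication
  }
  std::vector<double> decoded;
  for (std::size_t i = 0; i < encoded.size(); i += 10) {
    decoded.push_back(encoded[i]);  // every 10th value is the original
  }
  return decoded;  // equal to `pairs`
}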

void *MockProcessor::ProcessAggregation(std::size_t *size, std::map<int, std::vector<int>> nodes) {
  int total_bin_size = cuts_.back();
  int histo_size = total_bin_size*2;
  *size = kPrefixLen + 8*histo_size*nodes.size();
  int64_t buf_size = *size;
  int8_t *buf = static_cast<int8_t *>(calloc(buf_size, 1));
  memcpy(buf, kSignature, strlen(kSignature));
  memcpy(buf + 8, &buf_size, 8);
  memcpy(buf + 16, &kDataTypeHisto, 8);

  double *histo = reinterpret_cast<double *>(buf + kPrefixLen);
  for (const auto &node : nodes) {
    auto rows = node.second;
    for (const auto &row_id : rows) {
      auto num = cuts_.size() - 1;
      for (std::size_t f = 0; f < num; f++) {
        int slot = slots_[f + num*row_id];
        if ((slot < 0) || (slot >= total_bin_size)) {
          continue;
        }

        auto g = (*gh_pairs_)[row_id*2];
        auto h = (*gh_pairs_)[row_id*2+1];
        histo[slot*2] += g;
        histo[slot*2+1] += h;
      }
    }
    histo += histo_size;
  }

  return buf;
}
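For orientation, the buffer produced above holds one (g, h) histogram per tree node, laid out back to back after the DAM prefix, with g and h interleaved per bin. A short sketch of the size arithmetic under assumed inputs (the numbers are purely illustrative):

#include <cstddef>
#include <cstdint>

// Sketch only: the ProcessAggregation buffer size for an assumed
// 256 total bins (cuts_.back()) and 4 tree nodes (nodes.size()).
std::size_t AggregationBufferSizeSketch() {
  const int64_t prefix_len = 24;               // DAM prefix
  const int total_bin_size = 256;              // assumed cuts_.back()
  const int histo_size = total_bin_size * 2;   // interleaved g and h per bin
  const std::size_t num_nodes = 4;             // assumed nodes.size()
  return prefix_len + 8 * histo_size * num_nodes;  // 24 + 8*512*4 = 16408 bytes
}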

std::vector<double> MockProcessor::HandleAggregation(void *buffer, std::size_t buf_size) {
  std::vector<double> result = std::vector<double>();

  int8_t* ptr = static_cast<int8_t *>(buffer);
  auto rest_size = buf_size;

  while (rest_size > kPrefixLen) {
    if (!ValidDam(ptr, rest_size)) {
      break;
    }
    int64_t *size_ptr = reinterpret_cast<int64_t *>(ptr + 8);
    double *array_start = reinterpret_cast<double *>(ptr + kPrefixLen);
    auto array_size = (*size_ptr - kPrefixLen)/8;
    result.insert(result.end(), array_start, array_start + array_size);
    rest_size -= *size_ptr;
    ptr = ptr + *size_ptr;
  }

  return result;
}

void* MockProcessor::ProcessHistograms(std::size_t *size, const std::vector<double>& histograms) {

Reviewer comment: If this is for test only, let's move it out into the test module.

Reply: Makes sense, the mock serves a testing purpose. @nvidianz what do you think?

  *size = kPrefixLen + histograms.size()*10*8;  // Assume encrypted size is 10x

  int64_t buf_size = *size;
  // This memory needs to be freed
  char *buf = static_cast<char *>(malloc(buf_size));
  memcpy(buf, kSignature, strlen(kSignature));
  memcpy(buf + 8, &buf_size, 8);
  memcpy(buf + 16, &kDataTypeAggregatedHisto, 8);

  // Simulate encryption by duplicating value 10 times
  int index = kPrefixLen;
  for (auto value : histograms) {
    for (std::size_t i = 0; i < 10; i++) {
      memcpy(buf+index, &value, 8);
      index += 8;
    }
  }

  return buf;
}

std::vector<double> MockProcessor::HandleHistograms(void *buffer, std::size_t buf_size) {
  std::vector<double> result = std::vector<double>();

  int8_t* ptr = static_cast<int8_t *>(buffer);
  auto rest_size = buf_size;

  while (rest_size > kPrefixLen) {
    if (!ValidDam(ptr, rest_size)) {
      break;
    }
    int64_t *size_ptr = reinterpret_cast<int64_t *>(ptr + 8);
    double *array_start = reinterpret_cast<double *>(ptr + kPrefixLen);
    auto array_size = (*size_ptr - kPrefixLen)/8;
    auto empty = result.empty();
    if (!empty) {
      if (result.size() != array_size / 10) {
        std::cout << "Histogram size doesn't match " << result.size()
                  << " != " << array_size << std::endl;
        return result;
      }
    }

    for (std::size_t i = 0; i < array_size/10; i++) {
      auto value = array_start[i*10];
      if (empty) {
        result.push_back(value);
      } else {
        result[i] += value;
      }
    }

    rest_size -= *size_ptr;
    ptr = ptr + *size_ptr;
  }

  return result;
}
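In the horizontal scheme, HandleHistograms walks a buffer that may contain several DAM frames (one per contributing party), decodes each mock-encrypted histogram, and sums them element-wise. A standalone sketch of that reduction, independent of the DAM framing:

#include <vector>

// Sketch only: element-wise sum of per-party histograms,
// mirroring the accumulation loop in HandleHistograms.
std::vector<double> SumHistogramsSketch(const std::vector<std::vector<double>>& per_party) {
  std::vector<double> result;
  for (const auto& histo : per_party) {
    if (result.empty()) {
      result = histo;                             // first party initializes the result
    } else {
      for (std::size_t i = 0; i < histo.size() && i < result.size(); ++i) {
        result[i] += histo[i];                    // later parties are added element-wise
      }
    }
  }
  return result;
}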
Reviewer comment: May I ask why the vertical federated learning section is being modified for horizontal learning?

Reply: This one adds additional horizontal functions to the vertical_P2 PR, so it in fact includes everything that PR has. Since that PR has not been merged, this one shows all the vertical modifications as well: "Modifications added beyond #10124".