Skip to content

Commit

Permalink
add key_length #63
Browse files Browse the repository at this point in the history
  • Loading branch information
QinbinLi committed Apr 1, 2023
1 parent 83ab595 commit f4f85c2
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docs/source/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ Parameters for Federated Setting
- ``server``: the server proposes candidate split points according to the range of each feature in horizontal FedTree.
- ``party``: the parties propose possible split points. Then, the server merge them and sample at most num_max_bin candidate split points in horizontal FedTree.

* ``key_length`` [default=512, type=int]
- Number of bits of the key used in encryption.
Parameters for GBDTs
--------------------

Expand Down
4 changes: 2 additions & 2 deletions include/FedTree/Encryption/paillier.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Paillier {
// this->random = source.random;
return *this;
}
void keygen(long keyLength);
void keygen(int keyLength);

NTL::ZZ encrypt(const NTL::ZZ &message) const;

Expand All @@ -27,7 +27,7 @@ class Paillier {

NTL::ZZ modulus;
NTL::ZZ generator;
long keyLength;
int keyLength;

//private:
NTL::ZZ p, q;
Expand Down
1 change: 1 addition & 0 deletions include/FedTree/FL/FLparam.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class FLParam {
bool joint_prediction; // For vertical FL, whether multiple parties jointly conduct prediction or not.
bool partial_model; // For vertical FL. If set to true, each party gets a partial tree with the split nodes using the local features. Otherwise, each party gets a full tree with all features.
GBDTParam gbdt_param; // parameters for the gbdt training
int key_length; // number of bits of the key used for encryption
};


Expand Down
4 changes: 2 additions & 2 deletions include/FedTree/FL/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ class Server : public Party {
party.paillier = paillier;
}

void homo_init() {
void homo_init(int keylength) {
#ifdef USE_CUDA
paillier.keygen();
// pailler_gmp = Pailler(1024);
// paillier = Paillier(paillier_gmp);
// paillier.keygen();
#else
paillier.keygen(512);
paillier.keygen(keylength);
#endif
}

Expand Down
4 changes: 2 additions & 2 deletions src/FedTree/Encryption/paillier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ NTL::ZZ lcm(const NTL::ZZ &x, const NTL::ZZ &y) {
return lcm;
}

void GenPrimePair(NTL::ZZ &p, NTL::ZZ &q, long keyLength) {
void GenPrimePair(NTL::ZZ &p, NTL::ZZ &q, int keyLength) {
/* Prime pair generation function. Generates a prime pair in same bit length.
*
* Parameters
Expand All @@ -63,7 +63,7 @@ void GenPrimePair(NTL::ZZ &p, NTL::ZZ &q, long keyLength) {

Paillier::Paillier() = default;

void Paillier::keygen(long keyLength) {
void Paillier::keygen(int keyLength) {
/* Paillier parameters generation function. Generates paillier parameters from scrach.
*
* Parameters
Expand Down
4 changes: 2 additions & 2 deletions src/FedTree/FL/FLtrainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void FLtrainer::horizontal_fl_trainer(vector<Party> &parties, Server &server, FL
if (params.privacy_tech == "he") {
LOG(INFO) << "Start HE Init";
// server generate public key and private key
server.homo_init();
server.homo_init(params.key_length);
// server distribute public key to rest of parties
for (int i = 0; i < n_parties; i++) {
parties[i].paillier = server.paillier;
Expand Down Expand Up @@ -553,7 +553,7 @@ void FLtrainer::vertical_fl_trainer(vector<Party> &parties, Server &server, FLPa
auto t1 = timer.now();
temp_gradients.resize(server.booster.gradients.size());
temp_gradients.copy_from(server.booster.gradients);
server.homo_init();
server.homo_init(params.key_length);
server.encrypt_gh_pairs(server.booster.gradients);
auto t2 = timer.now();
std::chrono::duration<float> t3 = t2 - t1;
Expand Down
2 changes: 1 addition & 1 deletion src/FedTree/FL/distributed_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1160,7 +1160,7 @@ grpc::Status DistributedServer::ScoreReduce(grpc::ServerContext* context, const
grpc::Status DistributedServer::TriggerHomoInit(grpc::ServerContext *context, const fedtree::PID *request,
fedtree::Ready *response) {
// LOG(INFO) << "computation HomoInit start";
homo_init();
homo_init(param.key_length);
homo_init_success = true;
// LOG(INFO) << "computation HomoInit end";
return grpc::Status::OK;
Expand Down
3 changes: 3 additions & 0 deletions src/FedTree/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ void Parser::parse_param(FLParam &fl_param, char *file_path) {
fl_param.seed = 42;
fl_param.n_features = -1;
fl_param.joint_prediction = true;
fl_param.key_length = 512;

GBDTParam *gbdt_param = &fl_param.gbdt_param;

Expand Down Expand Up @@ -117,6 +118,8 @@ void Parser::parse_param(FLParam &fl_param, char *file_path) {
fl_param.n_features = atoi(val);
else if ((str_name.compare("joint_prediction") == 0))
fl_param.joint_prediction = atoi(val);
else if ((str_name.compare("key_length") == 0))
fl_param.key_length = atoi(val);
// GBDT params
else if ((str_name.compare("max_depth") == 0) || (str_name.compare("depth") == 0))
gbdt_param->depth = atoi(val);
Expand Down

0 comments on commit f4f85c2

Please sign in to comment.