diff --git a/docs/source/Parameters.rst b/docs/source/Parameters.rst index 3935691d..fa0be994 100644 --- a/docs/source/Parameters.rst +++ b/docs/source/Parameters.rst @@ -46,6 +46,8 @@ Parameters for Federated Setting - ``server``: the server proposes candidate split points according to the range of each feature in horizontal FedTree. - ``party``: the parties propose possible split points. Then, the server merge them and sample at most num_max_bin candidate split points in horizontal FedTree. +* ``key_length`` [default=512, type=int] + - Number of bits of the key used in encryption. Parameters for GBDTs -------------------- diff --git a/include/FedTree/Encryption/paillier.h b/include/FedTree/Encryption/paillier.h index 053d3117..00ae1ace 100644 --- a/include/FedTree/Encryption/paillier.h +++ b/include/FedTree/Encryption/paillier.h @@ -15,7 +15,7 @@ class Paillier { // this->random = source.random; return *this; } - void keygen(long keyLength); + void keygen(int keyLength); NTL::ZZ encrypt(const NTL::ZZ &message) const; @@ -27,7 +27,7 @@ class Paillier { NTL::ZZ modulus; NTL::ZZ generator; - long keyLength; + int keyLength; //private: NTL::ZZ p, q; diff --git a/include/FedTree/FL/FLparam.h b/include/FedTree/FL/FLparam.h index 44b2d7af..4974fd86 100644 --- a/include/FedTree/FL/FLparam.h +++ b/include/FedTree/FL/FLparam.h @@ -32,6 +32,7 @@ class FLParam { bool joint_prediction; // For vertical FL, whether multiple parties jointly conduct prediction or not. bool partial_model; // For vertical FL. If set to true, each party gets a partial tree with the split nodes using the local features. Otherwise, each party gets a full tree with all features. GBDTParam gbdt_param; // parameters for the gbdt training + int key_length; // number of bits of the key used for encryption }; diff --git a/include/FedTree/FL/server.h b/include/FedTree/FL/server.h index 87c9820d..ee189015 100644 --- a/include/FedTree/FL/server.h +++ b/include/FedTree/FL/server.h @@ -55,14 +55,14 @@ class Server : public Party { party.paillier = paillier; } - void homo_init() { + void homo_init(int keylength) { #ifdef USE_CUDA paillier.keygen(); // pailler_gmp = Pailler(1024); // paillier = Paillier(paillier_gmp); // paillier.keygen(); #else - paillier.keygen(512); + paillier.keygen(keylength); #endif } diff --git a/src/FedTree/Encryption/paillier.cpp b/src/FedTree/Encryption/paillier.cpp index 279854cf..09cc8186 100644 --- a/src/FedTree/Encryption/paillier.cpp +++ b/src/FedTree/Encryption/paillier.cpp @@ -40,7 +40,7 @@ NTL::ZZ lcm(const NTL::ZZ &x, const NTL::ZZ &y) { return lcm; } -void GenPrimePair(NTL::ZZ &p, NTL::ZZ &q, long keyLength) { +void GenPrimePair(NTL::ZZ &p, NTL::ZZ &q, int keyLength) { /* Prime pair generation function. Generates a prime pair in same bit length. * * Parameters @@ -63,7 +63,7 @@ void GenPrimePair(NTL::ZZ &p, NTL::ZZ &q, long keyLength) { Paillier::Paillier() = default; -void Paillier::keygen(long keyLength) { +void Paillier::keygen(int keyLength) { /* Paillier parameters generation function. Generates paillier parameters from scrach. * * Parameters diff --git a/src/FedTree/FL/FLtrainer.cpp b/src/FedTree/FL/FLtrainer.cpp index e89d1abf..d00ed675 100644 --- a/src/FedTree/FL/FLtrainer.cpp +++ b/src/FedTree/FL/FLtrainer.cpp @@ -26,7 +26,7 @@ void FLtrainer::horizontal_fl_trainer(vector &parties, Server &server, FL if (params.privacy_tech == "he") { LOG(INFO) << "Start HE Init"; // server generate public key and private key - server.homo_init(); + server.homo_init(params.key_length); // server distribute public key to rest of parties for (int i = 0; i < n_parties; i++) { parties[i].paillier = server.paillier; @@ -553,7 +553,7 @@ void FLtrainer::vertical_fl_trainer(vector &parties, Server &server, FLPa auto t1 = timer.now(); temp_gradients.resize(server.booster.gradients.size()); temp_gradients.copy_from(server.booster.gradients); - server.homo_init(); + server.homo_init(params.key_length); server.encrypt_gh_pairs(server.booster.gradients); auto t2 = timer.now(); std::chrono::duration t3 = t2 - t1; diff --git a/src/FedTree/FL/distributed_server.cpp b/src/FedTree/FL/distributed_server.cpp index 56547bc1..8e5dcae6 100644 --- a/src/FedTree/FL/distributed_server.cpp +++ b/src/FedTree/FL/distributed_server.cpp @@ -1160,7 +1160,7 @@ grpc::Status DistributedServer::ScoreReduce(grpc::ServerContext* context, const grpc::Status DistributedServer::TriggerHomoInit(grpc::ServerContext *context, const fedtree::PID *request, fedtree::Ready *response) { // LOG(INFO) << "computation HomoInit start"; - homo_init(); + homo_init(param.key_length); homo_init_success = true; // LOG(INFO) << "computation HomoInit end"; return grpc::Status::OK; diff --git a/src/FedTree/parser.cpp b/src/FedTree/parser.cpp index 9b932daa..a1123c29 100644 --- a/src/FedTree/parser.cpp +++ b/src/FedTree/parser.cpp @@ -47,6 +47,7 @@ void Parser::parse_param(FLParam &fl_param, char *file_path) { fl_param.seed = 42; fl_param.n_features = -1; fl_param.joint_prediction = true; + fl_param.key_length = 512; GBDTParam *gbdt_param = &fl_param.gbdt_param; @@ -117,6 +118,8 @@ void Parser::parse_param(FLParam &fl_param, char *file_path) { fl_param.n_features = atoi(val); else if ((str_name.compare("joint_prediction") == 0)) fl_param.joint_prediction = atoi(val); + else if ((str_name.compare("key_length") == 0)) + fl_param.key_length = atoi(val); // GBDT params else if ((str_name.compare("max_depth") == 0) || (str_name.compare("depth") == 0)) gbdt_param->depth = atoi(val);