-
Notifications
You must be signed in to change notification settings - Fork 2
/
server.py
109 lines (82 loc) · 5.04 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#Importing necessary libraries
import flwr as fl
import numpy as np
import encryption as encr
import tenseal as ts
import filedata as fd
class SaveModelStrategy(fl.server.strategy.FedAvg):
    """FedAvg strategy that also averages the clients' CKKS-encrypted parameters.

    The mode is taken from ``encr.Enc_needed.encryption_needed.value``:

    * ``1`` (full encryption): homomorphically averages
      ``data_encrypted_Client1.txt`` and ``data_encrypted_Client2.txt`` and
      writes the serialized result to ``result_ex.txt``.
    * ``2`` (partial encryption): same for ``data_encrypted_2_Client1.txt`` /
      ``data_encrypted_2_Client2.txt``, written to ``result_ex_2.txt``.
    * any other value: no encrypted work is done.

    In every mode the plaintext results are still aggregated with the built-in
    ``FedAvg.aggregate_fit``. The encrypted average is only written to disk;
    the FedAvg result is what this strategy returns to the server.
    """

    def __init__(self, min_fit_clients, min_available_clients):
        """Create the strategy and, if encryption is enabled, the CKKS keys.

        Args:
            min_fit_clients: minimum number of clients used for training.
            min_available_clients: minimum number of fit results required
                before a round is aggregated (checked in ``aggregate_fit``).
        """
        # Give FedAvg the real thresholds so its own constructor checks and
        # client sampling use them, rather than its defaults.
        super().__init__(
            min_fit_clients=min_fit_clients,
            min_available_clients=min_available_clients,
        )
        self.min_fit_clients = min_fit_clients
        self.min_available_clients = min_available_clients
        # Kept for compatibility; not read anywhere in this class.
        self.avg_accuracy = 0.0
        # The server is the starting point that synchronises communication,
        # so it creates the CKKS context (secret and public keys) via
        # encryption.create_context().
        if encr.Enc_needed.encryption_needed.value:
            encr.create_context()

    def aggregate_fit(self, rnd, results, failures):
        """Aggregate one round of fit results.

        Args:
            rnd: current server round (logged and forwarded to FedAvg).
            results: fit results received from clients this round.
            failures: failed fits, forwarded to FedAvg unchanged.

        Returns:
            The ``(parameters, metrics)`` pair from ``FedAvg.aggregate_fit``.
            If fewer than ``min_available_clients`` results arrived, returns
            ``(None, {})``.
        """
        if len(results) < self.min_available_clients:
            print(f"Not enough clients available (have {len(results)}, need {self.min_available_clients}). Skipping round {rnd}.")
            # The Flower server unpacks aggregate_fit's result into
            # (parameters, metrics); a bare None would crash the round.
            return None, {}

        mode = encr.Enc_needed.encryption_needed.value
        if mode == 1:  # Full encryption is selected
            self._aggregate_encrypted(
                ("data_encrypted_Client1.txt", "data_encrypted_Client2.txt"),
                "result_ex.txt",
            )
        elif mode == 2:  # Partial encryption is selected
            self._aggregate_encrypted(
                ("data_encrypted_2_Client1.txt", "data_encrypted_2_Client2.txt"),
                "result_ex_2.txt",
            )
        # The encrypted average above is only stored to a file. To keep the
        # simulation running, the plaintext parameters are aggregated with
        # the built-in FedAvg in every mode.
        return super().aggregate_fit(rnd, results, failures)

    def _aggregate_encrypted(self, input_paths, output_path):
        """Average the CKKS tensors stored in *input_paths* and save to *output_path*.

        Args:
            input_paths: files holding serialized CKKS tensors, one per client.
            output_path: file that receives the serialized encrypted average.

        Raises:
            ValueError: if *input_paths* is empty.
        """
        if not input_paths:
            raise ValueError("at least one encrypted input file is required")
        # Load the public context, which is used to compute on encrypted data.
        # It is loaded only here, because public_key.txt is created only when
        # encryption is enabled.
        public_key_context = ts.context_from(fd.read_data("public_key.txt"))
        encrypted_sum = None
        for path in input_paths:
            client_tensor = ts.lazy_ckks_tensor_from(fd.read_data(path))
            client_tensor.link_context(public_key_context)
            encrypted_sum = client_tensor if encrypted_sum is None else encrypted_sum + client_tensor
        # CKKS has no ciphertext division, so average by multiplying with an
        # encrypted 1/N (0.5 for two clients).
        denominator = ts.ckks_tensor(
            public_key_context, ts.plain_tensor([1.0 / len(input_paths)])
        )
        encrypted_average = encrypted_sum * denominator
        fd.write_data(output_path, encrypted_average.serialize())
# Create the strategy and run the server.
# Both thresholds are 2: SaveModelStrategy's encrypted aggregation reads
# exactly two client files (Client1 and Client2).
min_fit_clients = 2
min_available_clients = 2
strategy = SaveModelStrategy(min_fit_clients, min_available_clients)

# Number of federated-learning rounds the server runs.
NUM_ROUNDS = 10

# Start the Flower server for NUM_ROUNDS rounds of federated learning.
fl.server.start_server(
    server_address="localhost:8080",
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    # 1 GiB gRPC message limit. Presumably raised so large model/encrypted
    # payloads fit; confirm.
    grpc_max_message_length=1024 * 1024 * 1024,
    strategy=strategy,
)