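"""User-side logic for ShuffleFL-style federated learning.

A User owns a set of edge devices, builds a knowledge-distillation (KD)
dataset from them with one of several sampling strategies, distills the
device models into a server-side student model, and tracks the resulting
transfer latency and test accuracy.
"""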
import datasets
import numpy as np
import torch
from sklearn.cluster import KMeans

from config import DEVICE, LABEL_NAME
from data import fedmd
from optimize import optimize_transmission_matrices
from shuffle import shuffle_data


class User:
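    """A user in the ShuffleFL hierarchy.

    A user owns several devices, trains a server-side model from them via
    knowledge distillation, and logs its test accuracy after every round.
    """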

    def __init__(self, id, devices, testset, model, sampling) -> None:
        self.id = id
        self.devices = devices
        self.model = model
        self.testset = testset
        self.kd_dataset = None
        self.sampling = sampling
        self.log = []
        self.latency = 0

    def _aggregate_updates(
        self,
        epochs,
        learning_rate=0.001,
        T=3.0,
        soft_target_loss_weight=0.4,
        ce_loss_weight=0.6,
    ):
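        """Distill the device models (teachers) into the server model (student).

        The student minimizes a weighted sum of a KL-divergence loss against
        the mean of the teachers' temperature-softened outputs and a
        cross-entropy loss against the true labels.
        """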
        # Train the server model on the KD dataset using knowledge distillation
        student = self.model().to(DEVICE)
        student.load_state_dict(
            torch.load("checkpoints/server.pth", map_location=DEVICE)
        )
        student.train()
        optimizer = torch.optim.Adam(student.parameters(), lr=learning_rate)
        # TODO: Test 2 options
        # option 1 -> each device acts as a teacher separately (takes more time)
        # option 2 -> average the teacher logits from all devices (can be done in parallel)
        teachers = []
        for device in self.devices:
            teacher = device.model().to(DEVICE)
            teacher.load_state_dict(
                torch.load(f"checkpoints/device_{device.id}.pth", map_location=DEVICE)
            )
            teacher.eval()
            teachers.append(teacher)
        train_loader = torch.utils.data.DataLoader(
            self.kd_dataset, shuffle=True, drop_last=True, batch_size=32, num_workers=3
        )
        ce_loss = torch.nn.CrossEntropyLoss()
        kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
        # NOTE: the running statistics below are accumulated but not yet reported
        running_loss = 0.0
        running_kd_loss = 0.0
        running_ce_loss = 0.0
        running_accuracy = 0.0
        for _ in range(epochs):
            for batch in train_loader:
                inputs, labels = batch["img"].to(DEVICE), batch[LABEL_NAME].to(DEVICE)
                optimizer.zero_grad()
                # Forward pass with the teacher models - no gradients here,
                # since the teachers' weights are not updated
                with torch.no_grad():
                    # Keep the teacher outputs for the soft targets
                    teacher_targets = []
                    for teacher in teachers:
                        logits = teacher(inputs)
                        targets = torch.nn.functional.softmax(logits / T, dim=-1)
                        teacher_targets.append(targets)
                # Forward pass with the student model
                student_logits = student(inputs)
                soft_targets = torch.mean(torch.stack(teacher_targets), dim=0)
                soft_prob = torch.nn.functional.log_softmax(student_logits / T, dim=-1)
                # Soft-target loss, scaled by T**2 as suggested by the authors of
                # "Distilling the Knowledge in a Neural Network"
                soft_targets_loss = kl_loss(soft_prob, soft_targets) * (T ** 2)
                # True-label loss
                label_loss = ce_loss(student_logits, labels)
                # Weighted sum of the two losses
                loss = (soft_target_loss_weight * soft_targets_loss) + (
                    ce_loss_weight * label_loss
                )
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                running_kd_loss += soft_targets_loss.item()
                running_ce_loss += label_loss.item()
                running_accuracy += (
                    (torch.max(student_logits, 1)[1] == labels).sum().item()
                )
        torch.save(student.state_dict(), f"checkpoints/user_{self.id}.pth")

    # Train all the devices belonging to the user
    # Steps 11-15 in the ShuffleFL algorithm
    def train(self, kd_epochs, on_device_epochs):
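        """Run one federated round: on-device training, then KD aggregation.

        On the first round, the initial device models are checkpointed and
        the KD dataset is created.
        """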
        if self.kd_dataset is None:
            for device in self.devices:
                torch.save(
                    device.model().state_dict(), f"checkpoints/device_{device.id}.pth"
                )
            self.create_kd_dataset()
        for device in self.devices:
            device.update_model(self.model, self.kd_dataset)
            device.train(on_device_epochs)
        self._aggregate_updates(epochs=kd_epochs)
        self.test()

    def n_samples(self):
        return len(self.kd_dataset)

    def test(self):
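        """Evaluate the latest user checkpoint on the held-out test set."""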
        net = self.model().to(DEVICE)
        net.load_state_dict(
            torch.load(f"checkpoints/user_{self.id}.pth", map_location=DEVICE)
        )
        net.eval()
        testloader = torch.utils.data.DataLoader(
            self.testset, batch_size=32, shuffle=False, num_workers=3
        )
        correct, total = 0, 0
        with torch.no_grad():
            for batch in testloader:
                images, labels = batch["img"].to(DEVICE), batch[LABEL_NAME].to(DEVICE)
                outputs = net(images)
                correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
                total += labels.size(0)
        print(f"USER {self.id} test accuracy: {correct / total}")
        self.log.append(correct / total)

    # Functions to determine whether the data on the devices is non-IID
    # Types of non-IID data:
    # Feature distribution skew: compare the mean and variance of the features for corresponding classes
    # Label distribution skew: compare the number of samples per class on each device
    # Quantity skew: compare the total number of samples on each device
    # Concept drift: compare the performance of the model across devices
    def create_kd_dataset(self):
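        """Build the KD dataset once, using the configured sampling strategy."""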
        # Sample a dataset from the devices
        # Types of sampling:
        # Full
        # Size-proportional
        # Fair
        # Upload-proportional
        # Balance-proportional
        # Balanced
        # Adaptive
        # Shuffle-optimized
        # FedMD
        if self.kd_dataset is not None:
            return
        if self.sampling == "full":
            self.kd_dataset = self._sample_full()
        elif self.sampling == "size-proportional":
            self.kd_dataset = self._sample_size_proportional(0.4)
        elif self.sampling == "fair":
            self.kd_dataset = self._sample_fair(800)
        elif self.sampling == "balance-proportional":
            self.kd_dataset = self._sample_balance_proportional(4000)
        elif self.sampling == "upload-proportional":
            self.kd_dataset = self._sample_upload_proportional(4000)
        elif self.sampling == "balanced":
            self.kd_dataset = self._sample_balanced(40)
        elif self.sampling == "adaptive":
            self.kd_dataset = self._sample_adaptive()
        elif self.sampling == "shuffle-optimized":
            self.kd_dataset = self._shuffle()
        elif self.sampling == "fedmd":
            self.kd_dataset = fedmd()
        else:
            raise ValueError(f"Sampling method {self.sampling} not recognized")

    # Functions to sample the devices
    def _sample_devices(self, fractions):
        dataset = None
        for device, f in zip(self.devices, fractions):
            if dataset is None:
                dataset = device.sample(f)
            else:
                dataset = datasets.concatenate_datasets([dataset, device.sample(f)])
        return dataset

    def _sample_devices_amount(self, amounts):
        dataset = None
        for device, amount in zip(self.devices, amounts):
            if dataset is None:
                dataset = device.sample_amount(amount)
            else:
                dataset = datasets.concatenate_datasets(
                    [dataset, device.sample_amount(amount)]
                )
        return dataset

    def _sample_full(self):
        # Each device contributes all of its dataset; latency is bounded by the
        # slowest transfer (uplink is the per-sample upload cost)
        self.latency = max(device.uplink * len(device.dataset) for device in self.devices)
        return self._sample_devices([1 for _ in self.devices])

    def _sample_size_proportional(self, p):
        # Each device contributes a fraction p of its dataset
        fractions = [p for _ in self.devices]
        self.latency = max(device.uplink * len(device.dataset) * p for device in self.devices)
        return self._sample_devices(fractions)

    def _sample_fair(self, k):
        # Each device contributes the same number of samples
        amounts = [k for _ in self.devices]
        self.latency = max(device.uplink for device in self.devices) * k
        dataset = self._sample_devices_amount(amounts)
        print(f"Length of fair dataset: {len(dataset)}")
        return dataset

    def _sample_balance_proportional(self, dataset_length):
        # Weight each device by the inverse of its label imbalance
        balances = [1 / (device.imbalance() + 1e-12) for device in self.devices]
        # Sample a higher fraction from the devices with less imbalance
        fractions = [b / sum(balances) for b in balances]
        # Round down to integer sample counts
        amounts = [int(f * dataset_length) for f in fractions]
        self.latency = max(
            device.uplink * amount for device, amount in zip(self.devices, amounts)
        )
        return self._sample_devices_amount(amounts)

    def _sample_upload_proportional(self, dataset_length):
        # Sample more from the devices with faster upload speeds (lower uplink cost)
        upload_speeds = [1 / device.uplink for device in self.devices]
        fractions = [u / sum(upload_speeds) for u in upload_speeds]
        # Round down to integer sample counts
        amounts = [int(f * dataset_length) for f in fractions]
        self.latency = max(
            device.uplink * amount for device, amount in zip(self.devices, amounts)
        )
        return self._sample_devices_amount(amounts)

    def _sample_balanced(self, samples_per_class):
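        """Greedily build a label-balanced dataset with samples_per_class samples per class."""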
        # Sample greedily (starting with the device with the fastest upload)
        # so that the final dataset's labels are balanced
        devices_by_upload = sorted(self.devices, key=lambda x: x.uplink)
        n_classes = 100  # assumes a 100-class dataset (e.g. CIFAR-100)
        s = [samples_per_class for _ in range(n_classes)]
        dataset = None
        latencies = []
        # Take as many still-needed samples as possible from each device in turn
        for device in devices_by_upload:
            device_latency = 0
            for class_id in range(n_classes):
                if s[class_id] <= 0:
                    continue
                samples = device.sample_amount_class(s[class_id], class_id)
                if dataset is None:
                    dataset = samples
                else:
                    dataset = datasets.concatenate_datasets([dataset, samples])
                s[class_id] -= len(samples)
                device_latency += device.uplink * len(samples)
            latencies.append(device_latency)
        self.latency = max(latencies)
        print(f"Length of balanced dataset: {len(dataset)}")
        return dataset

    def _sample_adaptive(self):
        # Identify the kind of non-IID data on the devices and choose a
        # sampling technique accordingly
        if self._label_distribution_skew():
            return self._sample_balance_proportional(4000)
        elif self._bottleneck():
            return self._sample_upload_proportional(4000)
        elif self._quantity_skew():
            return self._sample_size_proportional(0.4)
        else:
            return self._sample_fair(800)

    # Implements section 4.4 from ShuffleFL
    def _reduce_dimensionality(self):
        kmeans_estimator = self._compute_centroids()
        for device in self.devices:
            device.cluster(kmeans_estimator)

    def _shuffle(self):
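        """ShuffleFL shuffling: cluster device data, optimize the transmission
        matrices, then exchange data between devices.

        Returns the KD dataset produced by the shuffle and records the
        resulting transfer latency in self.latency.
        """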
        # Reduce dimensionality of the transmission matrices
        # ShuffleFL steps 7, 8
        self._reduce_dimensionality()
        # The user optimizes the transmission matrices
        # ShuffleFL step 9
        adaptive_scaling_factor = 1
        cluster_distributions = [device.cluster_distribution() for device in self.devices]
        # uplink/downlink are per-sample transfer costs, so their inverses are rates.
        # NOTE: the original code computed downlinks from device.uplink, which looks
        # like a copy-paste bug; this assumes Device exposes a downlink attribute.
        uplinks = [1 / device.uplink for device in self.devices]
        downlinks = [1 / device.downlink for device in self.devices]
        computes = [device.compute for device in self.devices]
        transition_matrices = optimize_transmission_matrices(
            adaptive_scaling_factor, cluster_distributions, uplinks, downlinks, computes
        )
        # Shuffle the data and update the transition matrices
        # Implements Equation 1 from ShuffleFL
        ds = [device.dataset for device in self.devices]
        clusters = [device.clusters for device in self.devices]
        print("Shuffle optimized")
        res_datasets, kd_dataset, latency = shuffle_data(
            ds, clusters, cluster_distributions, transition_matrices, self.devices
        )
        print("Shuffle complete")
        self.latency = latency
        print(f"Shuffle latency: {self.latency}")
        # Update the devices with the new datasets
        for device, dataset in zip(self.devices, res_datasets):
            device.dataset = dataset
        return kd_dataset

    def _compute_centroids(self):
        # Pull the full dataset from every device
        dataset = self._sample_devices([1 for _ in self.devices])
        features = np.array(dataset["img"]).reshape(len(dataset), -1)
        # Cluster the datapoints into n_clusters groups using KMeans
        n_clusters = 5
        kmeans = KMeans(n_clusters=n_clusters).fit(features)
        # TODO: persist the fitted estimator to a file
        return kmeans

    # Functions to determine data heterogeneity on the devices
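    # Each detector computes per-device z-scores and flags heterogeneity when
    # any |z| exceeds 1.5; the 1e-6 term guards against a zero standard deviation.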
    def _quantity_skew(self) -> bool:
        # Compare the total number of samples on each device
        quantities = np.array([device.n_samples() for device in self.devices])
        z = np.abs((quantities - np.mean(quantities)) / (np.std(quantities) + 1e-6))
        return bool(np.any(z > 1.5))

    def _label_distribution_skew(self) -> bool:
        # Compare the per-device label imbalance
        labels = np.array([device.imbalance() for device in self.devices])
        z = np.abs((labels - np.mean(labels)) / (np.std(labels) + 1e-6))
        return bool(np.any(z > 1.5))

    def _bottleneck(self) -> bool:
        # Compare the per-sample upload cost of each device
        latencies = np.array([device.uplink for device in self.devices])
        z = np.abs((latencies - np.mean(latencies)) / (np.std(latencies) + 1e-6))
        return bool(np.any(z > 1.5))

    def _quality_skew(self) -> bool:
        # TODO: sample a small amount of data from each device and compare quality
        return False

    def _feature_skew(self) -> bool:
        # TODO: sample a small amount of data from each device and compare feature statistics
        return False