async_strategy.py
import math
import os
import pickle
from logging import DEBUG, WARNING
from time import time
from typing import List, Tuple

import numpy as np

import flwr.server.strategy.aggregate as agg
from flwr.common import (
    NDArrays,
    Parameters,
    log,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)


class AsynchronousStrategy:
    """Abstract base class for all asynchronous strategies."""

    def __init__(self, total_samples: int, staleness_alpha: float, fedasync_mixing_alpha: float, fedasync_a: float, num_clients: int, async_aggregation_strategy: str,
                 use_staleness: bool, use_sample_weighing: bool, send_gradients: bool, server_artificial_delay: bool) -> None:
        self.total_samples = total_samples
        self.staleness_alpha = staleness_alpha
        self.fedasync_a = fedasync_a
        self.fedasync_mixing_alpha = fedasync_mixing_alpha
        self.num_clients = num_clients
        self.async_aggregation_strategy = async_aggregation_strategy
        self.use_staleness = use_staleness
        self.use_sample_weighing = use_sample_weighing
        self.send_gradients = send_gradients
        self.server_artificial_delay = server_artificial_delay

    def average(self, global_parameters: Parameters, model_update_parameters: Parameters, t_diff: float, num_samples: int) -> Parameters:
        """Compute the average of the global and client parameters."""
        if self.async_aggregation_strategy == "fedasync":
            if self.send_gradients:
                return self.weighted_merge_fedasync(global_parameters, model_update_parameters, t_diff, num_samples)
            else:
                return self.weighted_average_fedasync(global_parameters, model_update_parameters, t_diff, num_samples)
        # elif self.async_aggregation_strategy == "asyncfeded":
        #     return self.weighted_average_asyncfeded(global_parameters, model_update_parameters, t_diff, num_samples)
        # elif self.async_aggregation_strategy == "unweighted":
        #     return self.unweighted_average(global_parameters, model_update_parameters)
        else:
            raise ValueError(
                f"Invalid async aggregation strategy: {self.async_aggregation_strategy}")

    def unweighted_average(self, global_parameters: Parameters, model_update_parameters: Parameters) -> Parameters:
        """Compute the unweighted average of the global and client parameters."""
        return ndarrays_to_parameters(agg.aggregate([(parameters_to_ndarrays(global_parameters), 1),
                                                     (parameters_to_ndarrays(model_update_parameters), 1)]))

    def weighted_average_asyncfeded(self, global_parameters: Parameters, model_update_parameters: Parameters, t_diff: float, num_samples: int) -> Parameters:
        """Compute the weighted average of the global and client parameters.

        Inspired by the AsyncFedED paper: https://arxiv.org/pdf/2205.13797.pdf
        """
        return ndarrays_to_parameters(self.aggregate_asyncfeded(parameters_to_ndarrays(global_parameters), parameters_to_ndarrays(model_update_parameters), t_diff, num_samples=num_samples))

    def get_sample_weight_coeff(self, num_samples: int) -> float:
        """Compute the sample weight coefficient: the client's share of the total number of samples."""
        return num_samples / self.total_samples

    def weighted_average_fedasync(self, global_parameters: Parameters, model_update_parameters: Parameters, t_diff: float, num_samples: int) -> Parameters:
        """Compute the weighted average of the global and client parameters.

        Inspired by the FedAsync paper: https://arxiv.org/pdf/1903.03934.pdf
        """
        return ndarrays_to_parameters(self.aggregate_fedasync(parameters_to_ndarrays(global_parameters), parameters_to_ndarrays(model_update_parameters), t_diff, num_samples=num_samples))

    def busy_wait(self, seconds: float) -> None:
        """Busy-wait for the specified number of seconds."""
        start = time()
        while time() - start < seconds:
            pass

    def aggregate_fedasync(self, global_param_arr: NDArrays, model_update_param_arr: NDArrays, t_diff: float, num_samples: int) -> NDArrays:
        """Compute the weighted average params_new = (1 - alpha) * params_old + alpha * model_update_params."""
        # Start from the base mixing coefficient and scale it by the optional staleness and sample weights
        alpha_coeff = self.fedasync_mixing_alpha
        if self.use_staleness:
            alpha_coeff *= self.get_staleness_weight_coeff_fedasync_poly(
                t_diff=t_diff)
        if self.use_sample_weighing:
            alpha_coeff *= self.get_sample_weight_coeff(num_samples)
        if self.server_artificial_delay:
            self.busy_wait(0.5)
        # log(DEBUG, f"t_diff: {t_diff}\nalpha_coeff: {alpha_coeff}")
        return [(1 - alpha_coeff) * layer_global + alpha_coeff * layer_update for layer_global, layer_update in zip(global_param_arr, model_update_param_arr)]
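
    # Worked example (assumed values, not taken from the original code): with
    # fedasync_mixing_alpha = 0.5, fedasync_a = 0.5 and t_diff = 3, the polynomial
    # staleness weight is (3 + 1) ** -0.5 = 0.5, so alpha_coeff = 0.5 * 0.5 = 0.25
    # and every layer becomes 0.75 * layer_global + 0.25 * layer_update.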

    def weighted_merge_fedasync(self, global_parameters: Parameters, gradients: Parameters, t_diff: float, num_samples: int) -> Parameters:
        """Add gradients to the global model. Inspired by the FedAsync paper: https://arxiv.org/pdf/1903.03934.pdf

        This is not, however, the same procedure as in the original paper: the paper aggregates MODELS, whereas here we aggregate GRADIENTS.
        """
        if self.server_artificial_delay:
            self.busy_wait(1)
        return ndarrays_to_parameters(self.add_grads_fedasync(parameters_to_ndarrays(global_parameters), parameters_to_ndarrays(gradients), t_diff, num_samples=num_samples))

    def add_grads_fedasync(self, global_param_arr: NDArrays, gradients_arr: NDArrays, t_diff: float, num_samples: int) -> NDArrays:
        """Compute the weighted average params_new = (1 - alpha) * params_old + alpha * (params_old + update_grads)."""
        # Start from the base mixing coefficient and scale it by the optional staleness and sample weights
        alpha_coeff = self.fedasync_mixing_alpha
        if self.use_staleness:
            alpha_coeff *= self.get_staleness_weight_coeff_fedasync_poly(
                t_diff=t_diff)
        if self.use_sample_weighing:
            alpha_coeff *= self.get_sample_weight_coeff(num_samples)
        # log(DEBUG, f"t_diff: {t_diff}\nalpha_coeff: {alpha_coeff}")
        return [(1 - alpha_coeff) * layer_global + alpha_coeff * (layer_global + layer_grad) for layer_global, layer_grad in zip(global_param_arr, gradients_arr)]
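
    # Note (added for clarity): since (1 - alpha) * w + alpha * (w + g) = w + alpha * g,
    # the merge above is equivalent to taking a step of size alpha_coeff along the
    # client's gradients from the current global parameters.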

    # See paper: https://arxiv.org/pdf/2205.13797.pdf
    def aggregate_asyncfeded(self, global_param_arr: NDArrays, model_update_param_arr: NDArrays, t_diff: float, num_samples: int) -> NDArrays:
        """Compute the new parameters using the formula params_new = params_old + eta * (model_update_params - params_old),
        where eta is influenced by the staleness of the model update and/or the number of samples.
        """
        eta = 1
        if self.use_staleness:
            # Staleness-weighted coefficient
            eta *= self.get_staleness_weight_coeff_paflm(t_diff=t_diff)
        if self.use_sample_weighing:
            eta *= self.get_sample_weight_coeff(num_samples)
        log(DEBUG, f"t_diff: {t_diff}\neta: {eta}")
        return [layer_global + eta * (layer_update - layer_global) for layer_global, layer_update in zip(global_param_arr, model_update_param_arr)]
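
    # Note (added for clarity): params_old + eta * (update - params_old) equals
    # (1 - eta) * params_old + eta * update, i.e. a convex combination of the global
    # model and the client update whenever 0 <= eta <= 1.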

    # See paper for more details: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9022982
    def get_staleness_weight_coeff_paflm(self, t_diff: float) -> float:
        """Compute the staleness weight beta_P = staleness_alpha ** ((t_diff / num_clients) - 1)."""
        mu_staleness = t_diff
        exponent = ((1 / float(self.num_clients)) * mu_staleness) - 1
        beta_P = math.pow(self.staleness_alpha, exponent)
        return beta_P

    # Paper: https://arxiv.org/pdf/1903.03934.pdf
    def get_staleness_weight_coeff_fedasync_constant(self) -> float:
        """Constant staleness weighting: every update is weighted equally regardless of staleness."""
        return 1.0

    def get_staleness_weight_coeff_fedasync_poly(self, t_diff: float) -> float:
        """Polynomial staleness weighting: (t_diff + 1) ** -a, decaying with staleness."""
        return math.pow(t_diff + 1, -self.fedasync_a)

    def get_staleness_weight_coeff_fedasync_hinge(self, t_diff: float, a: float = 10, b: float = 4) -> float:
        """Hinge staleness weighting: 1 up to staleness b, then 1 / (a * (t_diff - b) + 1)."""
        return 1 if t_diff <= b else 1 / (a * (t_diff - b) + 1)
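

# Minimal usage sketch (added for illustration; the constructor values below are
# assumptions, not taken from the original repository). It builds two dummy
# two-layer models, wraps them as flwr Parameters, and applies the FedAsync-style
# weighted average with a staleness of 2.0 rounds and 100 local samples.
if __name__ == "__main__":
    strategy = AsynchronousStrategy(
        total_samples=1000,
        staleness_alpha=0.5,
        fedasync_mixing_alpha=0.5,
        fedasync_a=0.5,
        num_clients=10,
        async_aggregation_strategy="fedasync",
        use_staleness=True,
        use_sample_weighing=True,
        send_gradients=False,
        server_artificial_delay=False,
    )
    global_model = ndarrays_to_parameters([np.zeros((2, 2)), np.zeros(2)])
    client_update = ndarrays_to_parameters([np.ones((2, 2)), np.ones(2)])
    new_model = parameters_to_ndarrays(
        strategy.average(global_model, client_update, t_diff=2.0, num_samples=100))
    print("Aggregated first layer:\n", new_model[0])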