-
Notifications
You must be signed in to change notification settings - Fork 0
/
Strategy.py
68 lines (55 loc) · 1.93 KB
/
Strategy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from abc import ABC, abstractmethod
from pickle import NONE
from Client import Client
from typing import Sequence
import torch
from Courier import Courier
class Strategy(ABC):
    """Abstract base for strategies that coordinate a courier and its clients.

    Concrete subclasses must implement ``aggregate`` and ``update_all``;
    they are free to extend those signatures with extra parameters.
    """

    def __init__(self, courier: Courier, clients: Sequence[Client]):
        # Stored as plain attributes; subclasses access them directly.
        self.courier = courier
        self.clients = clients

    @abstractmethod
    def aggregate(self):
        """Combine the clients' outputs; the result is defined by the subclass."""
        ...

    @abstractmethod
    def update_all(self):
        """Apply an update step across all clients; defined by the subclass."""
        ...
class SyncConcatStrategy(Strategy):
    """Synchronous strategy that concatenates the clients' embeddings.

    Each call to ``aggregate`` asks every client to run (``fit`` during
    training, ``predict`` during evaluation), then concatenates the values
    collected in ``courier.message_pool`` along dimension 1.
    """

    def __init__(self, courier, clients):
        super().__init__(courier, clients)

    def is_all_set(self):
        """Return True when no value in the courier's message pool is None."""
        return not any(elem is None for elem in self.courier.message_pool.values())

    def update_all(self, loss):
        """Backpropagate ``loss`` and then call ``update()`` on every client.

        Args:
            loss: Object whose ``backward()`` is called (presumably a scalar
                torch tensor — confirm with callers).
        """
        loss.backward()
        # An explicit loop is required: ``map`` is lazy in Python 3, so the
        # previous ``map(lambda client: client.update(), ...)`` never called
        # ``update()`` on any client.
        for client in self.clients:
            client.update()

    def aggregate(self, eval=False):
        """Run one round on every client and concatenate their outputs.

        Args:
            eval: If False, call ``client.fit()`` on each client; otherwise
                call ``client.predict()``. (The name shadows the builtin
                ``eval`` but is kept for caller compatibility.)

        Returns:
            The message-pool values concatenated with ``torch.cat`` along
            dim 1, after which the courier is flushed. Returns ``None`` if
            any pool entry is still ``None``; the pool is not flushed then.
        """
        if not eval:
            for client in self.clients:
                client.fit()
        else:
            for client in self.clients:
                client.predict()

        # Check all clients sent their embedding before concatenating.
        # NOTE(review): presumably fit()/predict() post into
        # courier.message_pool — confirm in Client.
        if not self.is_all_set():
            return None

        # Concatenation order follows the pool's iteration order.
        emb_list = list(self.courier.message_pool.values())
        embs = torch.cat(emb_list, dim=1)
        self.courier.flush()
        return embs
class SyncSTGConcatStrategy(SyncConcatStrategy):
    """Concatenation strategy that also trains with averaged regularization.

    Client models are expected to expose ``get_reg_loss()`` and
    ``get_gates()``. (Presumably "STG" refers to stochastic-gate feature
    selection — confirm.)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def update_all(self, loss, server_reg_loss):
        """Backpropagate ``loss`` plus the mean regularization, then update clients.

        The regularization term is the mean of ``server_reg_loss`` and every
        client's ``model.get_reg_loss()`` (``len(self.clients) + 1`` terms).

        Args:
            loss: Main training loss.
            server_reg_loss: The server's regularization loss. It is not
                modified.
        """
        # Out-of-place additions. The former ``reg_loss += ...`` mutated the
        # caller's ``server_reg_loss`` tensor in place, which also fails under
        # autograd when that tensor is a leaf requiring grad.
        reg_loss = server_reg_loss
        for client in self.clients:
            reg_loss = reg_loss + client.model.get_reg_loss()
        reg_loss = reg_loss / (len(self.clients) + 1)

        total_loss = reg_loss + loss
        total_loss.backward()

        # An explicit loop is required: ``map`` is lazy in Python 3, so the
        # previous ``map(...)`` call never invoked ``update()``.
        for client in self.clients:
            client.update()

    def number_of_features(self):
        """Return the sum, over clients, of the count returned by ``get_gates()``.

        ``client.model.get_gates()`` is unpacked as a 2-tuple, and its second
        element is summed.
        """
        return sum(num for _, num in (client.model.get_gates() for client in self.clients))