-
Notifications
You must be signed in to change notification settings - Fork 0
/
client.py
83 lines (80 loc) · 2.85 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
from federal import *
import torch.multiprocessing as mp
from models import EPOCHS
# Command-line interface for launching federated clients.
#   --mode:   'test' or 'eval' launches one round of ClientCommonNet clients;
#             any other value (or omitting it) runs the search chosen by --model.
#   --model:  which search algorithm's clients to launch.
#   --client: how many client processes to start.
parser = argparse.ArgumentParser(description='client end usage')
parser.add_argument('--mode',
                    dest='mode',
                    action='store',
                    # Kept as a plain string (no type=bool): the launcher
                    # compares the value against 'test' / 'eval'.
                    default=None,
                    help="'test' or 'eval' to run evaluation clients; "
                         "omit to run the search selected by --model")
parser.add_argument('--model',
                    dest='model',
                    action='store',
                    # A tuple rather than a set, so the order of the choices in
                    # --help and in error messages is stable across runs.
                    choices=('fl-agnns', 'fl-random',
                             'fl-darts', 'fl-graphnas', 'fl-fednas'),
                    default='fl-agnns',
                    help='search model')
parser.add_argument('--client',
                    dest='client',
                    action='store',
                    type=int,
                    default=3,
                    help='the number of clients in the search')
args = parser.parse_args()

# Number of rounds the fl-random baseline runs; each round launches a fresh
# set of ClientCommonNet clients.
RANDOM_ROUNDS = 5


def _run_clients(clients):
    """Run ``client.work`` for every client in its own process and wait.

    All processes are started before any is joined, so the clients run
    concurrently. A non-zero exit code is reported on stderr rather than
    silently ignored; execution continues either way.

    Returns:
        The processes' exit codes, in the same order as ``clients``.
    """
    import sys

    processes = []
    for client in clients:
        process = mp.Process(target=client.work)
        process.start()
        processes.append(process)
    exit_codes = []
    for process in processes:
        process.join()
        if process.exitcode != 0:
            print(f'warning: client process {process.name} exited with code '
                  f'{process.exitcode}', file=sys.stderr)
        exit_codes.append(process.exitcode)
    return exit_codes


def _common_clients(count):
    """Build ``count`` ClientCommonNet clients with ids 0..count-1."""
    return [ClientCommonNet(j) for j in range(count)]


# Launch the client processes; --mode is checked first, then --model.
if __name__ == '__main__':
    # 'spawn' starts each child in a fresh interpreter (presumably needed
    # because the clients use torch/CUDA state -- confirm).
    mp.set_start_method('spawn')
    if args.mode == 'test' or args.mode == 'eval':
        # Evaluation: a single round of common-net clients, whatever --model is.
        _run_clients(_common_clients(args.client))
    elif args.model == 'fl-random':
        for _ in range(RANDOM_ROUNDS):
            _run_clients(_common_clients(args.client))
    elif args.model == 'fl-graphnas':
        # One round per epoch; EPOCHS is imported from models.
        for _ in range(EPOCHS):
            _run_clients(_common_clients(args.client))
    else:
        # The remaining --model choices each have a dedicated client class.
        if args.model == 'fl-agnns':
            client_cls = ClientSuperNet
        elif args.model == 'fl-darts':
            client_cls = ClientDarts
        elif args.model == 'fl-fednas':
            client_cls = ClientFedNas
        else:
            # Unreachable while argparse restricts --model to its choices.
            raise ValueError(f'unsupported --model {args.model!r}')
        _run_clients([client_cls(i) for i in range(args.client)])