Skip to content

Commit

Permalink
Update experiments.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sjtudyq authored Jul 8, 2022
1 parent 38e6e93 commit e2ce570
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,19 @@ def init_nets(net_configs, dropout_p, n_parties, args):
elif args.dataset in {'a9a', 'covtype', 'rcv1', 'SUSY'}:
n_classes = 2
if args.use_projection_head:
add = ""
if "mnist" in args.dataset and args.model == "simple-cnn":
add = "-mnist"
for net_i in range(n_parties):
net = ModelFedCon(args.model, args.out_dim, n_classes, net_configs)
net = ModelFedCon(args.model+add, args.out_dim, n_classes, net_configs)
nets[net_i] = net
else:
if args.alg == 'moon':
add = ""
if "mnist" in args.dataset and args.model == "simple-cnn":
add = "-mnist"
for net_i in range(n_parties):
net = ModelFedCon_noheader(args.model, args.out_dim, n_classes, net_configs)
net = ModelFedCon_noheader(args.model+add, args.out_dim, n_classes, net_configs)
nets[net_i] = net
else:
for net_i in range(n_parties):
Expand Down

0 comments on commit e2ce570

Please sign in to comment.