diff --git a/experiments.py b/experiments.py index 9894453..672ead4 100644 --- a/experiments.py +++ b/experiments.py @@ -670,7 +670,7 @@ def local_train_net_scaffold(nets, selected, global_model, c_nets, c_global, arg logger.info("net %d final test acc %f" % (net_id, testacc)) avg_acc += testacc for key in total_delta: - total_delta[key] /= len(selected) + total_delta[key] /= args.n_parties c_global_para = c_global.state_dict() for key in c_global_para: if c_global_para[key].type() == 'torch.LongTensor':