diff --git a/trojanzoo/trainer.py b/trojanzoo/trainer.py index af0e9d28..582c43b1 100644 --- a/trojanzoo/trainer.py +++ b/trojanzoo/trainer.py @@ -22,7 +22,9 @@ class Trainer: name = 'trainer' param_list = ['optim_args', 'train_args', 'writer_args', - 'optimizer', 'lr_scheduler', 'pre_conditioner', 'writer'] + 'optimizer', 'lr_scheduler', + 'pre_conditioner', 'model_ema', + 'writer'] @classmethod def add_argument(cls, group: argparse._ArgumentGroup):