From 34fb2ddd3a0bc813303c0f6e21531ed75fc4f6c7 Mon Sep 17 00:00:00 2001
From: Emmanuel Benazera
Date: Mon, 2 Oct 2023 19:35:53 +0000
Subject: [PATCH] feat(ml): optim weight decay parameter control

---
 models/base_model.py                       |  3 +++
 models/cut_model.py                        |  4 ++++
 models/cycle_gan_model.py                  |  2 ++
 models/palette_model.py                    |  1 +
 models/re_cycle_gan_semantic_mask_model.py |  1 +
 options/train_options.py                   |  6 ++++++
 train.py                                   | 20 +++++++++++---------
 7 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/models/base_model.py b/models/base_model.py
index b6b173629..3ea23d717 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -221,6 +221,7 @@ def init_semantic_cls(self, opt):
             self.netCLS.parameters(),
             lr=opt.train_sem_lr_f_s,
             betas=(opt.train_beta1, opt.train_beta2),
+            weight_decay=opt.train_optim_weight_decay,
         )

         if opt.train_cls_regression:
@@ -322,6 +323,7 @@ def init_semantic_mask(self, opt):
                 ),
                 lr=opt.train_sem_lr_f_s,
                 betas=(opt.train_beta1, opt.train_beta2),
+                weight_decay=opt.train_optim_weight_decay,
             )
         else:
             self.optimizer_f_s = opt.optim(
@@ -329,6 +331,7 @@ def init_semantic_mask(self, opt):
                 self.netf_s.parameters(),
                 lr=opt.train_sem_lr_f_s,
                 betas=(opt.train_beta1, opt.train_beta2),
+                weight_decay=opt.train_optim_weight_decay,
             )

         self.optimizers.append(self.optimizer_f_s)
diff --git a/models/cut_model.py b/models/cut_model.py
index 05e3719cf..f4aab55a0 100644
--- a/models/cut_model.py
+++ b/models/cut_model.py
@@ -267,6 +267,7 @@ def __init__(self, opt, rank):
                 self.netG_A.parameters(),
                 lr=opt.train_G_lr,
                 betas=(opt.train_beta1, opt.train_beta2),
+                weight_decay=opt.train_optim_weight_decay,
             )
             if self.opt.model_multimodal:
                 self.criterionZ = torch.nn.L1Loss()
@@ -275,6 +276,7 @@ def __init__(self, opt, rank):
                     self.netE.parameters(),
                     lr=opt.train_G_lr,
                     betas=(opt.train_beta1, opt.train_beta2),
+                    weight_decay=opt.train_optim_weight_decay,
                 )

             if len(self.discriminators_names) > 0:
@@ -294,6 +296,7 @@ def __init__(self, opt, rank):
                 D_parameters,
                 lr=opt.train_D_lr,
                 betas=(opt.train_beta1, opt.train_beta2),
+                weight_decay=opt.train_optim_weight_decay,
             )

         self.optimizers.append(self.optimizer_G)
@@ -435,6 +438,7 @@ def data_dependent_initialize(self, data):
                 self.netF.parameters(),
                 lr=self.opt.train_G_lr,
                 betas=(self.opt.train_beta1, self.opt.train_beta2),
+                weight_decay=self.opt.train_optim_weight_decay,
             )
             self.optimizers.append(self.optimizer_F)
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 25a1882e2..cad848325 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -152,6 +152,7 @@ def __init__(self, opt, rank):
             itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
             lr=opt.train_G_lr,
             betas=(opt.train_beta1, opt.train_beta2),
+            weight_decay=opt.train_optim_weight_decay,
         )

         D_parameters = itertools.chain(
@@ -166,6 +167,7 @@ def __init__(self, opt, rank):
             D_parameters,
             lr=opt.train_D_lr,
             betas=(opt.train_beta1, opt.train_beta2),
+            weight_decay=opt.train_optim_weight_decay,
         )

         self.optimizers.append(self.optimizer_G)
diff --git a/models/palette_model.py b/models/palette_model.py
index fc0ff5376..f076f6a85 100644
--- a/models/palette_model.py
+++ b/models/palette_model.py
@@ -339,6 +339,7 @@ def __init__(self, opt, rank):
             G_parameters,
             lr=opt.train_G_lr,
             betas=(opt.train_beta1, opt.train_beta2),
+            weight_decay=opt.train_optim_weight_decay,
         )

         self.optimizers.append(self.optimizer_G)
diff --git a/models/re_cycle_gan_semantic_mask_model.py b/models/re_cycle_gan_semantic_mask_model.py
index 7b34bd98f..58d486cdb 100644
--- a/models/re_cycle_gan_semantic_mask_model.py
+++ b/models/re_cycle_gan_semantic_mask_model.py
@@ -78,6 +78,7 @@ def __init__(self, opt):
             itertools.chain(self.netP_A.parameters(), self.netP_B.parameters()),
             lr=opt.alg_re_P_lr,
             betas=(opt.train_beta1, opt.train_beta2),
+            weight_decay=opt.train_optim_weight_decay,
         )
         self.optimizers.append(self.optimizer_P)

diff --git a/options/train_options.py b/options/train_options.py
index f26dcef91..17e80f3b1 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -173,6 +173,12 @@ def initialize(self, parser):
             choices=["adam", "radam", "adamw", "lion"],
             help="optimizer (adam, radam, adamw, ...)",
         )
+        parser.add_argument(
+            "--train_optim_weight_decay",
+            type=float,
+            default=0.0,
+            help="weight decay for optimizer",
+        )
         parser.add_argument(
             "--train_load_iter",
             type=int,
diff --git a/train.py b/train.py
index 46717c9b6..8252fdcf3 100644
--- a/train.py
+++ b/train.py
@@ -51,16 +51,18 @@ def setup(rank, world_size, port):
     dist.init_process_group("nccl", rank=rank, world_size=world_size)


-def optim(opt, params, lr, betas):
+def optim(opt, params, lr, betas, weight_decay):
     print("Using ", opt.train_optim, " as optimizer")
     if opt.train_optim == "adam":
-        return torch.optim.Adam(params, lr, betas)
+        return torch.optim.Adam(params, lr, betas, weight_decay=weight_decay)
     elif opt.train_optim == "radam":
-        return torch.optim.RAdam(params, lr, betas)
+        return torch.optim.RAdam(params, lr, betas, weight_decay=weight_decay)
     elif opt.train_optim == "adamw":
-        return torch.optim.AdamW(params, lr, betas)
+        if weight_decay == 0.0:
+            weight_decay = 0.01  # default value
+        return torch.optim.AdamW(params, lr, betas, weight_decay=weight_decay)
     elif opt.train_optim == "lion":
-        return Lion(params, lr, betas)
+        return Lion(params, lr, betas, weight_decay)


 def signal_handler(sig, frame):
@@ -257,12 +259,12 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal):
                     )

                     model.save_networks("latest")
-                    model.export_networks("latest")
+                    # model.export_networks("latest")

                 if opt.train_save_by_iter:
                     save_suffix = "iter_%d" % total_iters
                     model.save_networks(save_suffix)
-                    model.export_networks(save_suffix)
+                    # model.export_networks(save_suffix)

                 if total_iters % opt.train_metrics_every < batch_size and (
                     opt.train_compute_metrics_test
@@ -341,8 +343,8 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal):

             model.save_networks("latest")
             model.save_networks(epoch)
-            model.export_networks("latest")
-            model.export_networks(epoch)
+            # model.export_networks("latest")
+            # model.export_networks(epoch)

             if rank_0:
                 print(
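
Usage sketch (not part of the patch itself): after this change, every optimizer built
through the `opt.optim` factory receives the `--train_optim_weight_decay` value, and
AdamW falls back to its PyTorch default of 0.01 when the flag is left at 0.0. The
standalone Python sketch below reproduces that dispatch logic with stock PyTorch only;
the `make_optimizer` name and the toy model are illustrative, and Lion is omitted
because it lives in the third-party lion-pytorch package.

    import torch
    import torch.nn as nn

    def make_optimizer(name, params, lr, betas, weight_decay):
        # Mirrors the patched optim() in train.py: the decay value is
        # forwarded to each optimizer; AdamW keeps its library default
        # when the flag is unset (0.0).
        if name == "adam":
            return torch.optim.Adam(params, lr, betas, weight_decay=weight_decay)
        elif name == "radam":  # torch.optim.RAdam is available since PyTorch 1.10
            return torch.optim.RAdam(params, lr, betas, weight_decay=weight_decay)
        elif name == "adamw":
            if weight_decay == 0.0:
                weight_decay = 0.01  # AdamW's default in torch.optim
            return torch.optim.AdamW(params, lr, betas, weight_decay=weight_decay)
        raise ValueError("unknown optimizer: %s" % name)

    net = nn.Linear(8, 8)
    opt_G = make_optimizer("adamw", net.parameters(), 2e-4, (0.5, 0.999), 0.0)
    print(opt_G.defaults["weight_decay"])  # 0.01: the AdamW fallback kicked in

On the command line this becomes e.g. `python3 train.py --train_optim adamw
--train_optim_weight_decay 0.05` (remaining training flags elided). One consequence of
the fallback worth noting: AdamW can no longer be run with a weight decay of exactly
0.0, since that value is remapped to 0.01.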