
feat(ml): optim weight decay parameter control
beniz committed Oct 4, 2023
1 parent 95efb8a commit 34fb2dd
Showing 7 changed files with 28 additions and 9 deletions.
3 changes: 3 additions & 0 deletions models/base_model.py
@@ -221,6 +221,7 @@ def init_semantic_cls(self, opt):
             self.netCLS.parameters(),
             lr=opt.train_sem_lr_f_s,
             betas=(opt.train_beta1, opt.train_beta2),
+            weight_decay=opt.train_optim_weight_decay,
         )
 
         if opt.train_cls_regression:
@@ -322,13 +323,15 @@ def init_semantic_mask(self, opt):
                 ),
                 lr=opt.train_sem_lr_f_s,
                 betas=(opt.train_beta1, opt.train_beta2),
+                weight_decay=opt.train_optim_weight_decay,
             )
         else:
             self.optimizer_f_s = opt.optim(
                 opt,
                 self.netf_s.parameters(),
                 lr=opt.train_sem_lr_f_s,
                 betas=(opt.train_beta1, opt.train_beta2),
+                weight_decay=opt.train_optim_weight_decay,
             )
 
         self.optimizers.append(self.optimizer_f_s)
4 changes: 4 additions & 0 deletions models/cut_model.py
@@ -267,6 +267,7 @@ def __init__(self, opt, rank):
             self.netG_A.parameters(),
             lr=opt.train_G_lr,
             betas=(opt.train_beta1, opt.train_beta2),
+            weight_decay=opt.train_optim_weight_decay,
         )
         if self.opt.model_multimodal:
             self.criterionZ = torch.nn.L1Loss()
@@ -275,6 +276,7 @@ def __init__(self, opt, rank):
                 self.netE.parameters(),
                 lr=opt.train_G_lr,
                 betas=(opt.train_beta1, opt.train_beta2),
+                weight_decay=opt.train_optim_weight_decay,
             )
 
         if len(self.discriminators_names) > 0:
@@ -294,6 +296,7 @@ def __init__(self, opt, rank):
                 D_parameters,
                 lr=opt.train_D_lr,
                 betas=(opt.train_beta1, opt.train_beta2),
+                weight_decay=opt.train_optim_weight_decay,
             )
 
             self.optimizers.append(self.optimizer_G)
@@ -435,6 +438,7 @@ def data_dependent_initialize(self, data):
                 self.netF.parameters(),
                 lr=self.opt.train_G_lr,
                 betas=(self.opt.train_beta1, self.opt.train_beta2),
+                weight_decay=self.opt.train_optim_weight_decay,
             )
             self.optimizers.append(self.optimizer_F)
 
2 changes: 2 additions & 0 deletions models/cycle_gan_model.py
@@ -152,6 +152,7 @@ def __init__(self, opt, rank):
             itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
             lr=opt.train_G_lr,
             betas=(opt.train_beta1, opt.train_beta2),
+            weight_decay=opt.train_optim_weight_decay,
         )
 
         D_parameters = itertools.chain(
@@ -166,6 +167,7 @@ def __init__(self, opt, rank):
             D_parameters,
             lr=opt.train_D_lr,
             betas=(opt.train_beta1, opt.train_beta2),
+            weight_decay=opt.train_optim_weight_decay,
         )
 
         self.optimizers.append(self.optimizer_G)
1 change: 1 addition & 0 deletions models/palette_model.py
@@ -339,6 +339,7 @@ def __init__(self, opt, rank):
             G_parameters,
             lr=opt.train_G_lr,
             betas=(opt.train_beta1, opt.train_beta2),
+            weight_decay=opt.train_optim_weight_decay,
         )
         self.optimizers.append(self.optimizer_G)
 
1 change: 1 addition & 0 deletions models/re_cycle_gan_semantic_mask_model.py
@@ -78,6 +78,7 @@ def __init__(self, opt):
             itertools.chain(self.netP_A.parameters(), self.netP_B.parameters()),
             lr=opt.alg_re_P_lr,
             betas=(opt.train_beta1, opt.train_beta2),
+            weight_decay=opt.train_optim_weight_decay,
         )
         self.optimizers.append(self.optimizer_P)
 
6 changes: 6 additions & 0 deletions options/train_options.py
@@ -173,6 +173,12 @@ def initialize(self, parser):
             choices=["adam", "radam", "adamw", "lion"],
             help="optimizer (adam, radam, adamw, ...)",
         )
+        parser.add_argument(
+            "--train_optim_weight_decay",
+            type=float,
+            default=0.0,
+            help="weight decay for optimizer",
+        )
         parser.add_argument(
             "--train_load_iter",
             type=int,
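With this option in place, weight decay can be set directly from the training command line. A hypothetical invocation (the dataroot, checkpoints, and model flags below are illustrative placeholders, not part of this commit):

python3 train.py \
    --dataroot /path/to/dataset \
    --checkpoints_dir /path/to/checkpoints \
    --model_type cut \
    --train_optim adamw \
    --train_optim_weight_decay 0.01

Note that leaving --train_optim_weight_decay at its 0.0 default means no decay for adam, radam, and lion, while the adamw branch in train.py below falls back to 0.01, AdamW's usual default.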
20 changes: 11 additions & 9 deletions train.py
@@ -51,16 +51,18 @@ def setup(rank, world_size, port):
     dist.init_process_group("nccl", rank=rank, world_size=world_size)
 
 
-def optim(opt, params, lr, betas):
+def optim(opt, params, lr, betas, weight_decay):
     print("Using ", opt.train_optim, " as optimizer")
     if opt.train_optim == "adam":
-        return torch.optim.Adam(params, lr, betas)
+        return torch.optim.Adam(params, lr, betas, weight_decay=weight_decay)
     elif opt.train_optim == "radam":
-        return torch.optim.RAdam(params, lr, betas)
+        return torch.optim.RAdam(params, lr, betas, weight_decay=weight_decay)
     elif opt.train_optim == "adamw":
-        return torch.optim.AdamW(params, lr, betas)
+        if weight_decay == 0.0:
+            weight_decay = 0.01  # default value
+        return torch.optim.AdamW(params, lr, betas, weight_decay=weight_decay)
     elif opt.train_optim == "lion":
-        return Lion(params, lr, betas)
+        return Lion(params, lr, betas, weight_decay)
 
 
 def signal_handler(sig, frame):
@@ -257,12 +259,12 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal):
             )
 
             model.save_networks("latest")
-            model.export_networks("latest")
+            # model.export_networks("latest")
 
             if opt.train_save_by_iter:
                 save_suffix = "iter_%d" % total_iters
                 model.save_networks(save_suffix)
-                model.export_networks(save_suffix)
+                # model.export_networks(save_suffix)
 
             if total_iters % opt.train_metrics_every < batch_size and (
                 opt.train_compute_metrics_test
@@ -341,8 +343,8 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal):
             model.save_networks("latest")
             model.save_networks(epoch)
 
-            model.export_networks("latest")
-            model.export_networks(epoch)
+            # model.export_networks("latest")
+            # model.export_networks(epoch)
 
             if rank_0:
                 print(
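To make the dispatch above concrete, here is a minimal standalone sketch (not part of the commit) exercising the updated optim() factory; the Opt stub, the layer shape, and the hyperparameter values are hypothetical, and the lion branch is omitted since Lion is a third-party class:

import torch
import torch.nn as nn

class Opt:
    # Hypothetical stand-in for the parsed training options.
    train_optim = "adamw"

def optim(opt, params, lr, betas, weight_decay):
    # Mirrors the updated factory in train.py (lion branch omitted).
    if opt.train_optim == "adam":
        return torch.optim.Adam(params, lr, betas, weight_decay=weight_decay)
    elif opt.train_optim == "radam":
        return torch.optim.RAdam(params, lr, betas, weight_decay=weight_decay)
    elif opt.train_optim == "adamw":
        if weight_decay == 0.0:
            weight_decay = 0.01  # fall back to AdamW's usual default
        return torch.optim.AdamW(params, lr, betas, weight_decay=weight_decay)

net = nn.Linear(4, 2)
# Flag left at its 0.0 default: adam/radam apply no decay, adamw gets 0.01.
optimizer = optim(Opt(), net.parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=0.0)
print(optimizer.defaults["weight_decay"])  # -> 0.01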
