Commit

Merge branch 'main' into rpautrat/fix-gluestick-config
rpautrat authored Oct 23, 2023
2 parents 815ea7b + e0104fd commit 7e61e61
Showing 1 changed file with 23 additions and 17 deletions.
gluefactory/train.py: 40 changes (23 additions, 17 deletions)
@@ -48,11 +48,12 @@
     "optimizer_options": {},  # optional arguments passed to the optimizer
     "lr": 0.001,  # learning rate
     "lr_schedule": {
-        "type": None,
+        "type": None,  # string in {factor, exp, member of torch.optim.lr_scheduler}
         "start": 0,
         "exp_div_10": 0,
         "on_epoch": False,
         "factor": 1.0,
+        "options": {},  # add lr_scheduler arguments here
     },
     "lr_scaling": [(100, ["dampingnet.const"])],
     "eval_every_iter": 1000,  # interval for evaluation on the validation set
@@ -141,6 +142,26 @@ def filter_fn(x):
     return params


+def get_lr_scheduler(optimizer, conf):
+    """Get lr scheduler specified by conf.train.lr_schedule."""
+    if conf.type not in ["factor", "exp", None]:
+        return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)
+
+    # backward compatibility
+    def lr_fn(it):  # noqa: E306
+        if conf.type is None:
+            return 1
+        if conf.type == "factor":
+            return 1.0 if it < conf.start else conf.factor
+        if conf.type == "exp":
+            gam = 10 ** (-1 / conf.exp_div_10)
+            return 1.0 if it < conf.start else gam
+        else:
+            raise ValueError(conf.type)
+
+    return torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
+
+
 def pack_lr_parameters(params, base_lr, lr_scaling):
     """Pack each group of parameters with the respective scaled learning rate."""
     filters, scales = tuple(zip(*[(n, s) for s, names in lr_scaling for n in names]))
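
To make the backward-compatible "exp" branch concrete, here is a standalone sketch (not from the repository): with MultiplicativeLR, every scheduler step multiplies the learning rate by 10 ** (-1 / exp_div_10), so the rate decays by one order of magnitude every exp_div_10 iterations once start is reached. The dummy optimizer and the numbers are illustrative only.

import torch

# Standalone check of the "exp" schedule; start, exp_div_10 and lr are made up.
start, exp_div_10 = 0, 1000
gam = 10 ** (-1 / exp_div_10)
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
    optimizer, lambda it: 1.0 if it < start else gam
)
for _ in range(2000):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]["lr"])  # ~1e-5: two orders of magnitude in 2000 steps
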
@@ -310,22 +331,7 @@ def sigint_handler(signal, frame):

     results = None  # fix bug with it saving

-    def lr_fn(it):  # noqa: E306
-        if conf.train.lr_schedule.type is None:
-            return 1
-        if conf.train.lr_schedule.type == "factor":
-            return (
-                1.0
-                if it < conf.train.lr_schedule.start
-                else conf.train.lr_schedule.factor
-            )
-        if conf.train.lr_schedule.type == "exp":
-            gam = 10 ** (-1 / conf.train.lr_schedule.exp_div_10)
-            return 1.0 if it < conf.train.lr_schedule.start else gam
-        else:
-            raise ValueError(conf.train.lr_schedule.type)
-
-    lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
+    lr_scheduler = get_lr_scheduler(optimizer=optimizer, conf=conf.train.lr_schedule)
     if args.restore:
         optimizer.load_state_dict(init_cp["optimizer"])
         if "lr_scheduler" in init_cp:
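
For non-legacy types, get_lr_scheduler simply looks the class up on torch.optim.lr_scheduler and forwards conf.options; written out by hand, the new single-line call site is equivalent to the sketch below (the model, optimizer, and StepLR arguments are placeholders, not taken from the repository).

import torch

# Hand-expanded equivalent of the non-legacy path in get_lr_scheduler.
model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
name, options = "StepLR", {"step_size": 30, "gamma": 0.1}  # placeholder conf values
lr_scheduler = getattr(torch.optim.lr_scheduler, name)(optimizer, **options)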
