diff --git a/gluefactory/train.py b/gluefactory/train.py
index 08895d72..12ad2075 100644
--- a/gluefactory/train.py
+++ b/gluefactory/train.py
@@ -48,11 +48,12 @@
     "optimizer_options": {},  # optional arguments passed to the optimizer
     "lr": 0.001,  # learning rate
     "lr_schedule": {
-        "type": None,
+        "type": None,  # string in {factor, exp, member of torch.optim.lr_scheduler}
         "start": 0,
         "exp_div_10": 0,
         "on_epoch": False,
         "factor": 1.0,
+        "options": {},  # add lr_scheduler arguments here
     },
     "lr_scaling": [(100, ["dampingnet.const"])],
     "eval_every_iter": 1000,  # interval for evaluation on the validation set
@@ -141,6 +142,26 @@ def filter_fn(x):
     return params
 
 
+def get_lr_scheduler(optimizer, conf):
+    """Get lr scheduler specified by conf.train.lr_schedule."""
+    if conf.type not in ["factor", "exp", None]:
+        return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)
+
+    # backward compatibility
+    def lr_fn(it):  # noqa: E306
+        if conf.type is None:
+            return 1
+        if conf.type == "factor":
+            return 1.0 if it < conf.start else conf.factor
+        if conf.type == "exp":
+            gam = 10 ** (-1 / conf.exp_div_10)
+            return 1.0 if it < conf.start else gam
+        else:
+            raise ValueError(conf.type)
+
+    return torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
+
+
 def pack_lr_parameters(params, base_lr, lr_scaling):
     """Pack each group of parameters with the respective scaled learning rate."""
     filters, scales = tuple(zip(*[(n, s) for s, names in lr_scaling for n in names]))
@@ -310,22 +331,7 @@ def sigint_handler(signal, frame):
 
     results = None  # fix bug with it saving
 
-    def lr_fn(it):  # noqa: E306
-        if conf.train.lr_schedule.type is None:
-            return 1
-        if conf.train.lr_schedule.type == "factor":
-            return (
-                1.0
-                if it < conf.train.lr_schedule.start
-                else conf.train.lr_schedule.factor
-            )
-        if conf.train.lr_schedule.type == "exp":
-            gam = 10 ** (-1 / conf.train.lr_schedule.exp_div_10)
-            return 1.0 if it < conf.train.lr_schedule.start else gam
-        else:
-            raise ValueError(conf.train.lr_schedule.type)
-
-    lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
+    lr_scheduler = get_lr_scheduler(optimizer=optimizer, conf=conf.train.lr_schedule)
     if args.restore:
         optimizer.load_state_dict(init_cp["optimizer"])
         if "lr_scheduler" in init_cp:
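
For context, a rough usage sketch of the new `get_lr_scheduler` plumbing (assuming glue-factory is installed as a package; the `CosineAnnealingLR` choice and its `T_max` value are illustrative, not part of the patch):

```python
import torch
from omegaconf import OmegaConf

from gluefactory.train import get_lr_scheduler

# A throwaway model/optimizer, just to have parameters to schedule.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Any class name from torch.optim.lr_scheduler is accepted as "type";
# its keyword arguments are forwarded through "options".
new_style = OmegaConf.create(
    {"type": "CosineAnnealingLR", "options": {"T_max": 1000}}
)
scheduler = get_lr_scheduler(optimizer=optimizer, conf=new_style)

# The legacy types ("factor", "exp", or None) still resolve to the old
# MultiplicativeLR-based schedule for backward compatibility.
legacy = OmegaConf.create(
    {"type": "exp", "start": 100, "exp_div_10": 2, "factor": 1.0, "options": {}}
)
legacy_scheduler = get_lr_scheduler(optimizer=optimizer, conf=legacy)
```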