Skip to content

Commit

Permalink
feat: multistep lr scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Sep 16, 2024
1 parent 0ce3f89 commit 8079e76
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
4 changes: 4 additions & 0 deletions models/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ def lambda_rule(epoch):
scheduler = lr_scheduler.StepLR(
optimizer, step_size=opt.train_lr_decay_iters, gamma=0.1
)
elif opt.train_lr_policy == "multistep":
    # Decay LR by gamma=0.1 at each absolute epoch index listed in
    # opt.train_lr_steps (milestones), matching the hardcoded gamma of the
    # "step" policy above.
    scheduler = lr_scheduler.MultiStepLR(
        optimizer, milestones=opt.train_lr_steps, gamma=0.1
    )
elif opt.train_lr_policy == "plateau":
scheduler = lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.2, threshold=0.01, patience=5
Expand Down
12 changes: 10 additions & 2 deletions options/train_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def initialize(self, parser):
"--train_lr_policy",
type=str,
default="linear",
choices=["linear", "step", "plateau", "cosine"],
choices=["linear", "step", "multistep", "plateau", "cosine"],
help="learning rate policy.",
)
parser.add_argument(
Expand All @@ -307,6 +307,14 @@ def initialize(self, parser):
default=50,
help="multiply by a gamma every lr_decay_iters iterations",
)
parser.add_argument(
    "--train_lr_steps",
    default=[],
    nargs="*",
    type=int,
    # MultiStepLR expects absolute epoch indices (milestones), not intervals:
    # the LR is multiplied by gamma=0.1 each time one of these epochs is reached.
    help="for multistep policy: epochs at which the learning rate is multiplied by gamma=0.1 (increasing epoch indices, e.g. 30 60 90)",
)

parser.add_argument(
"--train_nb_img_max_fid",
type=int,
Expand Down Expand Up @@ -700,7 +708,7 @@ def _after_parse(self, opt, set_device=True):
)

# vitclip16 projector only works with input size 224
if opt.D_proj_network_type == "efficientnet":
if "projected_d" in opt.D_netDs and opt.D_proj_network_type == "efficientnet":
if opt.D_proj_interp < 224:
warnings.warn(
"Efficiennet projector has minimal input size of 224, setting D_proj_interp to 224"
Expand Down

0 comments on commit 8079e76

Please sign in to comment.