diff --git a/train.py b/train.py index 674e228d..94096d00 100644 --- a/train.py +++ b/train.py @@ -89,9 +89,9 @@ def train(opt): frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate ** frac opt.current_lr = opt.learning_rate * decay_factor - utils.set_lr(optimizer, opt.current_lr) # set the decayed rate else: opt.current_lr = opt.learning_rate + utils.set_lr(optimizer, opt.current_lr) # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every