Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 616333867
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Mar 16, 2024
1 parent 17bce3c commit d7ea2aa
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions swirl_dynamics/projects/ergodic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,23 @@ def main(argv):
else:
raise NotImplementedError(f"Unknown experiment: {config.experiment}")

# Slicing the decay options in the learning rate scheduler.
if "decay_rate" in config:
decay_rate = config.decay_rate
else:
decay_rate = 0.5

if "num_steps_for_decrease_lr" in config:
num_steps_for_decrease_lr = config.num_steps_for_decrease_lr
else:
num_steps_for_decrease_lr = config.train_steps_per_cycle

if config.use_lr_scheduler:
optimizer = optax.adam(
learning_rate=optax.exponential_decay(
init_value=config.lr,
transition_steps=config.train_steps_per_cycle,
decay_rate=0.5,
transition_steps=num_steps_for_decrease_lr,
decay_rate=decay_rate,
staircase=True,
)
)
Expand Down

0 comments on commit d7ea2aa

Please sign in to comment.