-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementing loss scaling scheduler callback and schedulers #270
Implementing loss scaling scheduler callback and schedulers #270
Conversation
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
Signed-off-by: Kin Long Kelvin Lee <[email protected]>
Signed-off-by: Kin Long Kelvin Lee <[email protected]>
Signed-off-by: Kin Long Kelvin Lee <[email protected]>
Signed-off-by: Kin Long Kelvin Lee <[email protected]>
Signed-off-by: Kin Long Kelvin Lee <[email protected]>
Signed-off-by: Kin Long Kelvin Lee <[email protected]>
This serves as the main interface for controlling the schedules during training
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
We can set the task scaling values, but if the task key isn't set the data may not be present
Seems like most schedules will need the same configuration
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
Signed-off-by: Lee, Kin Long Kelvin <[email protected]>
I think the same tests fail as in #266 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just one general comment, otherwise looks good!
matsciml/lightning/callbacks.py
Outdated
target_key = schedule.key | ||
self._logger.debug( | ||
f"Attempting to advance {target_key} schedule on step." | ||
) | ||
try: | ||
new_scaling_value = schedule.step() | ||
pl_module.task_loss_scaling[target_key] = new_scaling_value | ||
self._logger.debug( | ||
f"Advanced {target_key} to new value: {new_scaling_value}" | ||
) | ||
except StopIteration: | ||
self._logger.warning( | ||
f"{target_key} has run out of scheduled values; this may be unintentional." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like you could combine this common bit between the on_x_end functions, and pass in a 'step' or 'epoch' string variable to use in the if statement and log message.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in d3b97da
This PR adds a new callback, called
LossScalingScheduler
, and scheduler classes that will modify the relative loss weights over the course of training.LinearScalingSchedule
andSigmoidScalingSchedule
. The former will generate a linear ramp from start to end over steps or epochs, and the latter gives a gradual ramp up in the form of a sigmoid curve.LossScalingScheduler
is configured by passing schedules that are mapped to task keys, and for every training step or epoch (set by the schedules), applies a new value of the weighting to the appropriate task.examples/callbacks/loss_scheduling.py
is provided to show how it is configured.This is useful for implementing curricula, where we prioritize learning of different properties over time.