enh: Add on_checkpoint callback that triggers after every checkpoint even if the model was not evaluated (#3763)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justin Zhao <[email protected]>
3 people authored Nov 8, 2023
1 parent 266a36c commit 9389b16
Showing 3 changed files with 20 additions and 2 deletions.
5 changes: 5 additions & 0 deletions ludwig/callbacks.py
@@ -299,6 +299,11 @@ def should_early_stop(self, trainer, progress_tracker, is_coordinator):
         # Triggers early stopping if any callback on any worker returns True
         return False

+    def on_checkpoint(self, trainer, progress_tracker):
+        """Called after each checkpoint is passed, regardless of whether the model was evaluated or saved at that
+        checkpoint."""
+        pass
+
     def on_save_best_checkpoint(self, trainer, progress_tracker, save_path):
         """Called on every worker immediately after a new best model is checkpointed."""
         pass
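
For illustration only, not part of this diff: a downstream callback could override the new hook roughly as sketched below. The CheckpointLogger class and its logging are hypothetical; the on_checkpoint signature and the progress_tracker counters are taken from the changes in this commit.

import logging

from ludwig.callbacks import Callback

logger = logging.getLogger(__name__)


class CheckpointLogger(Callback):
    """Hypothetical callback that records every checkpoint, evaluated or not."""

    def on_checkpoint(self, trainer, progress_tracker):
        # progress_tracker carries the trainer's running counters (see trainer.py below).
        logger.info(
            "Checkpoint %s reached at step %s",
            progress_tracker.checkpoint_number,
            progress_tracker.steps,
        )
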
13 changes: 12 additions & 1 deletion ludwig/trainers/trainer.py
@@ -138,6 +138,7 @@ def __init__(
         self.enable_profiling = config.enable_profiling
         self.steps_per_epoch = 0  # Computed during training, after batcher has been initialized.
         self.total_steps = 0  # Computed during training, after batcher has been initialized.
+        self.total_expected_checkpoints = 0  # Computed during training, after batcher has been initialized.

         self.regularization_lambda = config.regularization_lambda
         self.regularization_type = config.regularization_type
@@ -651,7 +652,6 @@ def run_evaluation(
         start_time = time.time()
         self.callback(lambda c: c.on_eval_start(self, progress_tracker, save_path))

-        progress_tracker.checkpoint_number += 1
         if self.is_coordinator():
             logger.info(f"\nRunning evaluation for step: {progress_tracker.steps}, epoch: {progress_tracker.epoch}")
@@ -915,6 +915,8 @@ def train(
             )
             final_steps_per_checkpoint = min(final_steps_per_checkpoint, self.total_steps)
             early_stopping_steps = final_steps_per_checkpoint * self.early_stop
+            if not self.skip_save_progress:
+                self.total_expected_checkpoints = self.total_steps // final_steps_per_checkpoint + self.epochs

             # Initialize the learning rate scheduler.
             self.scheduler = LRScheduler(
@@ -993,10 +995,15 @@ def train(
                     f"{time_utils.strdelta((time.time() - start_time) * 1000.0)}."
                 )
             if not self.skip_save_progress:
+                progress_tracker.checkpoint_number += 1
+
                 checkpoint_manager.save(progress_tracker.steps)
                 if self.is_coordinator():
                     progress_tracker.save(os.path.join(save_path, TRAINING_PROGRESS_TRACKER_FILE_NAME))

+                # Callback that the checkpoint was reached, regardless of whether the model was evaluated.
+                self.callback(lambda c: c.on_checkpoint(self, progress_tracker))
+
             if not self.skip_save_model and self.skip_all_evaluation:
                 # All evaluation was skipped, so save the current step as the best so far.
                 checkpoint_manager.save_best(progress_tracker.steps)
@@ -1197,10 +1204,14 @@ def _train_loop(
             # this should not make a difference, except in the unlikely event an error occurs during eval and we
             # want to resume from the last checkpoint, in which case we will lose slightly more progress this way.
             if not self.skip_save_progress:
+                progress_tracker.checkpoint_number += 1
                 checkpoint_manager.save(progress_tracker.steps)
                 if self.is_coordinator():
                     progress_tracker.save(os.path.join(save_path, TRAINING_PROGRESS_TRACKER_FILE_NAME))

+                # Callback that the checkpoint was reached, regardless of whether the model was evaluated or not.
+                self.callback(lambda c: c.on_checkpoint(self, progress_tracker))
+
             # If this was the last batch, then increment the epoch counter and invoke the `on_epoch_end` callback.
             if batcher.last_batch():
                 self.callback(lambda c: c.on_epoch_end(self, progress_tracker, save_path))
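
To make the new bookkeeping concrete: total_expected_checkpoints is total_steps // final_steps_per_checkpoint + epochs, where the added epochs term presumably accounts for the end-of-epoch checkpoints. A small sketch with made-up numbers:

# Illustrative numbers only; they are not taken from the commit.
total_steps = 1000                # assumed total training steps for the run
final_steps_per_checkpoint = 100  # assumed checkpoint interval, in steps
epochs = 5                        # assumed number of epochs

# Mirrors the expression added to Trainer.train() above.
total_expected_checkpoints = total_steps // final_steps_per_checkpoint + epochs
print(total_expected_checkpoints)  # 10 interval checkpoints + 5 epoch-end checkpoints = 15
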
4 changes: 3 additions & 1 deletion ludwig/trainers/trainer_lightgbm.py
@@ -250,7 +250,6 @@ def run_evaluation(
         start_time = time.time()
         self.callback(lambda c: c.on_eval_start(self, progress_tracker, save_path))

-        progress_tracker.checkpoint_number += 1
         if self.is_coordinator():
             logger.info(f"\nRunning evaluation for step: {progress_tracker.steps}, epoch: {progress_tracker.epoch}")
@@ -684,10 +683,13 @@ def train(
                     f"{time_utils.strdelta((time.time()- start_time) * 1000.0)}."
                 )
             if not self.skip_save_progress:
+                progress_tracker.checkpoint_number += 1
                 checkpoint_manager.checkpoint.model = self.model
                 checkpoint_manager.save(progress_tracker.steps)
                 progress_tracker.save(os.path.join(save_path, TRAINING_PROGRESS_TRACKER_FILE_NAME))

+                self.callback(lambda c: c.on_checkpoint(self, progress_tracker))
+
             # Early stop if needed.
             if should_break:
                 break
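
A rough end-to-end usage sketch, assuming the hypothetical CheckpointLogger above and that LudwigModel accepts a callbacks list at construction; the config and file name below are made up:

from ludwig.api import LudwigModel

# Minimal made-up config; real feature names and files will differ.
config = {
    "input_features": [{"name": "text", "type": "text"}],
    "output_features": [{"name": "label", "type": "category"}],
    "trainer": {"epochs": 2},
}

# `callbacks` is assumed here to be the way Callback instances are registered.
model = LudwigModel(config, callbacks=[CheckpointLogger()])

# During train(), on_checkpoint now fires at every checkpoint, even when
# evaluation is skipped at that checkpoint.
# model.train(dataset="train.csv")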
