Fix missing checkpoint callbacks (ecmwf#125)
HCookie authored Nov 8, 2024
1 parent 9eb68a7 commit a26b05c
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -56,8 +56,8 @@ def nestedget(conf: DictConfig, key: str, default: Any) -> Any:
 ]
 
 
-def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | None:
-    """Get checkpointing callback."""
+def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint]:
+    """Get checkpointing callbacks."""
     if not config.diagnostics.get("enable_checkpointing", True):
         return []
 
@@ -89,6 +89,7 @@ def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | None:
             n_saved,
         )
 
+    checkpoint_callbacks = []
     if not config.diagnostics.profiler:
         for save_key, (
             name,
@@ -97,29 +98,27 @@ def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | None:
         ) in ckpt_frequency_save_dict.items():
             if save_frequency is not None:
                 LOGGER.debug("Checkpoint callback at %s = %s ...", save_key, save_frequency)
-                return (
+                checkpoint_callbacks.append(
                     # save_top_k: the save_top_k flag can either save the best or the last k checkpoints
                     # depending on the monitor flag on ModelCheckpoint.
                     # See https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html for reference
-                    [
-                        AnemoiCheckpoint(
-                            config=config,
-                            filename=name,
-                            save_last=True,
-                            **{save_key: save_frequency},
-                            # if save_top_k == k, last k models saved; if save_top_k == -1, all models are saved
-                            save_top_k=save_n_models,
-                            monitor="step",
-                            mode="max",
-                            **checkpoint_settings,
-                        ),
-                    ]
+                    AnemoiCheckpoint(
+                        config=config,
+                        filename=name,
+                        save_last=True,
+                        **{save_key: save_frequency},
+                        # if save_top_k == k, last k models saved; if save_top_k == -1, all models are saved
+                        save_top_k=save_n_models,
+                        monitor="step",
+                        mode="max",
+                        **checkpoint_settings,
+                    ),
                 )
             LOGGER.debug("Not setting up a checkpoint callback with %s", save_key)
     else:
         # the tensorboard logger + pytorch profiler cause pickling errors when writing checkpoints
         LOGGER.warning("Profiling is enabled - will not write any training or inference model checkpoints!")
-    return None
+    return checkpoint_callbacks
 
 
 def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]:
@@ -180,9 +179,7 @@ def get_callbacks(config: DictConfig) -> list[Callback]:
     trainer_callbacks: list[Callback] = []
 
     # Get Checkpoint callback
-    checkpoint_callback = _get_checkpoint_callback(config)
-    if checkpoint_callback is not None:
-        trainer_callbacks.extend(checkpoint_callback)
+    trainer_callbacks.extend(_get_checkpoint_callback(config))
 
     # Base callbacks
     trainer_callbacks.extend(
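For context, here is a minimal, hypothetical sketch of the pattern this commit adopts; it is not the anemoi-training code, and the function name, frequency dict, and plain-dict "callbacks" below are invented for illustration. The point it demonstrates: callbacks are accumulated across all configured save frequencies and returned as a list, whereas the old code returned from inside the loop (so only the first configured frequency ever produced a checkpoint callback) or returned None, which the caller had to special-case.

```python
# Hypothetical, simplified sketch of the fixed control flow (not the actual
# anemoi-training code): one callback per configured save frequency, always
# returning a list so the caller can unconditionally extend() with it.

def get_checkpoint_callbacks(frequencies: dict[str, int | None]) -> list[dict]:
    callbacks: list[dict] = []
    for save_key, save_frequency in frequencies.items():
        if save_frequency is None:
            continue  # this trigger is not configured
        # Old pattern: `return [make_callback(...)]` here (make_callback is a
        # placeholder), which dropped every frequency after the first match.
        callbacks.append({"trigger": save_key, "frequency": save_frequency})
    return callbacks  # possibly empty, never None


if __name__ == "__main__":
    cfg = {"every_n_epochs": 1, "every_n_train_steps": 1000, "train_time_interval": None}
    trainer_callbacks: list[dict] = []
    trainer_callbacks.extend(get_checkpoint_callbacks(cfg))  # no None check needed
    print(trainer_callbacks)
    # [{'trigger': 'every_n_epochs', 'frequency': 1},
    #  {'trigger': 'every_n_train_steps', 'frequency': 1000}]
```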
