diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py
index f6a7b894..f3597843 100644
--- a/src/anemoi/training/diagnostics/callbacks/__init__.py
+++ b/src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -56,8 +56,8 @@ def nestedget(conf: DictConfig, key: str, default: Any) -> Any:
 ]
 
 
-def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | None:
-    """Get checkpointing callback."""
+def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint]:
+    """Get checkpointing callbacks."""
     if not config.diagnostics.get("enable_checkpointing", True):
         return []
 
@@ -89,6 +89,7 @@ def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | Non
             n_saved,
         )
 
+    checkpoint_callbacks = []
     if not config.diagnostics.profiler:
         for save_key, (
             name,
@@ -97,29 +98,27 @@ def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | Non
         ) in ckpt_frequency_save_dict.items():
             if save_frequency is not None:
                 LOGGER.debug("Checkpoint callback at %s = %s ...", save_key, save_frequency)
-                return (
+                checkpoint_callbacks.append(
                     # save_top_k: the save_top_k flag can either save the best or the last k checkpoints
                     # depending on the monitor flag on ModelCheckpoint.
                     # See https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html for reference
-                    [
-                        AnemoiCheckpoint(
-                            config=config,
-                            filename=name,
-                            save_last=True,
-                            **{save_key: save_frequency},
-                            # if save_top_k == k, last k models saved; if save_top_k == -1, all models are saved
-                            save_top_k=save_n_models,
-                            monitor="step",
-                            mode="max",
-                            **checkpoint_settings,
-                        ),
-                    ]
+                    AnemoiCheckpoint(
+                        config=config,
+                        filename=name,
+                        save_last=True,
+                        **{save_key: save_frequency},
+                        # if save_top_k == k, last k models saved; if save_top_k == -1, all models are saved
+                        save_top_k=save_n_models,
+                        monitor="step",
+                        mode="max",
+                        **checkpoint_settings,
+                    ),
                 )
             LOGGER.debug("Not setting up a checkpoint callback with %s", save_key)
     else:
         # the tensorboard logger + pytorch profiler cause pickling errors when writing checkpoints
         LOGGER.warning("Profiling is enabled - will not write any training or inference model checkpoints!")
-    return None
+    return checkpoint_callbacks
 
 
 def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]:
@@ -180,9 +179,7 @@ def get_callbacks(config: DictConfig) -> list[Callback]:
     trainer_callbacks: list[Callback] = []
 
     # Get Checkpoint callback
-    checkpoint_callback = _get_checkpoint_callback(config)
-    if checkpoint_callback is not None:
-        trainer_callbacks.extend(checkpoint_callback)
+    trainer_callbacks.extend(_get_checkpoint_callback(config))
 
     # Base callbacks
    trainer_callbacks.extend(
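
A minimal standalone sketch of the contract change this diff makes, not repository code: `_get_checkpoint_callback` now always returns a list (possibly empty and possibly with several `AnemoiCheckpoint` instances) instead of `list | None`, so `get_callbacks` can extend unconditionally. The `enabled` parameter and stub callback strings below are illustrative assumptions.

```python
# Illustrative sketch (assumptions, not the repository code): why an
# always-present list return simplifies the caller versus `list | None`.
from typing import Any


def _get_checkpoint_callback_old(enabled: bool) -> list[Any] | None:
    # Old contract: may return None, so every caller needs a guard.
    return ["ckpt-callback"] if enabled else None


def _get_checkpoint_callback_new(enabled: bool) -> list[Any]:
    # New contract: always a list (possibly empty), safe to extend directly.
    return ["ckpt-callback"] if enabled else []


trainer_callbacks: list[Any] = []

# Old call site: a None check is required before extending.
maybe_callbacks = _get_checkpoint_callback_old(enabled=False)
if maybe_callbacks is not None:
    trainer_callbacks.extend(maybe_callbacks)

# New call site: unconditional extend, as in the diff above.
trainer_callbacks.extend(_get_checkpoint_callback_new(enabled=False))
```

Accumulating into `checkpoint_callbacks` also means one callback can be registered per configured save frequency, rather than returning after the first match.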