diff --git a/configs/ops/train/callbacks/callbacks_base.yaml b/configs/ops/train/callbacks/callbacks_base.yaml index 3739351..41badcf 100644 --- a/configs/ops/train/callbacks/callbacks_base.yaml +++ b/configs/ops/train/callbacks/callbacks_base.yaml @@ -12,5 +12,7 @@ callback_dict: save_model: # Stores the model after each epoch _target_: niceml.dlframeworks.keras.callbacks.callback_factories.ModelCallbackFactory - model_subfolder: models - model_filename: model-id_{short_id}-ep{epoch:03d} +# model_subfolder may include model name or declare them separately + model_subfolder: models/model-id_{short_id}-ep{epoch:03d} +# model_subfolder: models +# model_filename: model-id_{short_id}-ep{epoch:03d} diff --git a/niceml/dlframeworks/keras/callbacks/callback_factories.py b/niceml/dlframeworks/keras/callbacks/callback_factories.py index de6c0ce..2c219ba 100644 --- a/niceml/dlframeworks/keras/callbacks/callback_factories.py +++ b/niceml/dlframeworks/keras/callbacks/callback_factories.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from os.path import join from pathlib import Path -from typing import Any, List +from typing import Any, List, Optional from niceml.dlframeworks.keras.callbacks.csvlogger import CSVLogger from niceml.dlframeworks.keras.callbacks.modelcheckpoint import ( @@ -39,19 +39,25 @@ def create_callback(self, exp_context: ExperimentContext): class ModelCallbackFactory(CallbackFactory): """Creates the model checkpoint callback""" - def __init__(self, model_subfolder: str, model_filename: str, **kwargs): + def __init__( + self, model_subfolder: str, model_filename: Optional[str] = None, **kwargs + ): """ Initializes the ModelCallbackFactory object, which creates - ModelCheckpoint callbacks. + ModelCheckpoint callbacks. If model_filename is not given, it will + be inferred from the model_subfolder. Fileextensions will be ignored. Args: model_subfolder: name of the subfolder to save the model in - model_filename: filename of the model file without the file extension + model_filename: filename of the model file without the file extension. If + model_filename is not given, it will be inferred from the model_subfolder **kwargs: additional keyword arguments for ModelCheckpoint initialization """ self.kwargs = kwargs - self.model_subfolder = model_subfolder - self.model_filename = model_filename + self.model_subfolder = ( + model_subfolder if model_filename else str(Path(model_subfolder).parent) + ) + self.model_filename = model_filename or str(Path(model_subfolder).stem) def create_callback(self, exp_context: ExperimentContext) -> ModelCheckpoint: """ diff --git a/template/configs/ops/train/callbacks/callbacks_base.yaml b/template/configs/ops/train/callbacks/callbacks_base.yaml index bc7565a..3739351 100644 --- a/template/configs/ops/train/callbacks/callbacks_base.yaml +++ b/template/configs/ops/train/callbacks/callbacks_base.yaml @@ -12,4 +12,5 @@ callback_dict: save_model: # Stores the model after each epoch _target_: niceml.dlframeworks.keras.callbacks.callback_factories.ModelCallbackFactory - model_subfolder: models/model-id_{short_id}-ep{epoch:03d}.hdf5 + model_subfolder: models + model_filename: model-id_{short_id}-ep{epoch:03d}