Skip to content

Commit

Permalink
fix: config tests for old config
Browse files Browse the repository at this point in the history
  • Loading branch information
ankeko committed Aug 20, 2024
1 parent 3b208ca commit 7273211
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
6 changes: 4 additions & 2 deletions configs/ops/train/callbacks/callbacks_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@ callback_dict:
save_model:
# Stores the model after each epoch
_target_: niceml.dlframeworks.keras.callbacks.callback_factories.ModelCallbackFactory
model_subfolder: models
model_filename: model-id_{short_id}-ep{epoch:03d}
# model_subfolder may include model name or declare them separately
model_subfolder: models/model-id_{short_id}-ep{epoch:03d}
# model_subfolder: models
# model_filename: model-id_{short_id}-ep{epoch:03d}
18 changes: 12 additions & 6 deletions niceml/dlframeworks/keras/callbacks/callback_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from os.path import join
from pathlib import Path
from typing import Any, List
from typing import Any, List, Optional

from niceml.dlframeworks.keras.callbacks.csvlogger import CSVLogger
from niceml.dlframeworks.keras.callbacks.modelcheckpoint import (
Expand Down Expand Up @@ -39,19 +39,25 @@ def create_callback(self, exp_context: ExperimentContext):
class ModelCallbackFactory(CallbackFactory):
"""Creates the model checkpoint callback"""

def __init__(self, model_subfolder: str, model_filename: str, **kwargs):
def __init__(
self, model_subfolder: str, model_filename: Optional[str] = None, **kwargs
):
"""
Initializes the ModelCallbackFactory object, which creates
ModelCheckpoint callbacks.
ModelCheckpoint callbacks. If model_filename is not given, it will
be inferred from the model_subfolder. Fileextensions will be ignored.
Args:
model_subfolder: name of the subfolder to save the model in
model_filename: filename of the model file without the file extension
model_filename: filename of the model file without the file extension. If
model_filename is not given, it will be inferred from the model_subfolder
**kwargs: additional keyword arguments for ModelCheckpoint initialization
"""
self.kwargs = kwargs
self.model_subfolder = model_subfolder
self.model_filename = model_filename
self.model_subfolder = (
model_subfolder if model_filename else str(Path(model_subfolder).parent)
)
self.model_filename = model_filename or str(Path(model_subfolder).stem)

def create_callback(self, exp_context: ExperimentContext) -> ModelCheckpoint:
"""
Expand Down
3 changes: 2 additions & 1 deletion template/configs/ops/train/callbacks/callbacks_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ callback_dict:
save_model:
# Stores the model after each epoch
_target_: niceml.dlframeworks.keras.callbacks.callback_factories.ModelCallbackFactory
model_subfolder: models/model-id_{short_id}-ep{epoch:03d}.hdf5
model_subfolder: models
model_filename: model-id_{short_id}-ep{epoch:03d}

0 comments on commit 7273211

Please sign in to comment.