diff --git a/configs/vision/pathology/offline/segmentation/bcss.yaml b/configs/vision/pathology/offline/segmentation/bcss.yaml index b7c0f616..b0297274 100644 --- a/configs/vision/pathology/offline/segmentation/bcss.yaml +++ b/configs/vision/pathology/offline/segmentation/bcss.yaml @@ -6,6 +6,10 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/bcss} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 + plugins: + class_path: eva.core.plugins.SubmoduleTorchCheckpointIO + init_args: + submodule: decoder callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/offline/segmentation/consep.yaml b/configs/vision/pathology/offline/segmentation/consep.yaml index 79af2962..97bc55c6 100644 --- a/configs/vision/pathology/offline/segmentation/consep.yaml +++ b/configs/vision/pathology/offline/segmentation/consep.yaml @@ -6,6 +6,10 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/consep} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 + plugins: + class_path: eva.core.plugins.SubmoduleTorchCheckpointIO + init_args: + submodule: decoder callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/offline/segmentation/monusac.yaml b/configs/vision/pathology/offline/segmentation/monusac.yaml index 587f9984..1fe200ba 100644 --- a/configs/vision/pathology/offline/segmentation/monusac.yaml +++ b/configs/vision/pathology/offline/segmentation/monusac.yaml @@ -6,6 +6,10 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/monusac} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} log_every_n_steps: 6 + plugins: + class_path: eva.core.plugins.SubmoduleTorchCheckpointIO + init_args: + submodule: decoder callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml b/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml index 38080f1a..921aa1ff 100644 --- a/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml +++ b/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml @@ -5,6 +5,10 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 1} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/total_segmentator_2d} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000} + plugins: + class_path: eva.core.plugins.SubmoduleTorchCheckpointIO + init_args: + submodule: decoder callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/online/segmentation/bcss.yaml b/configs/vision/pathology/online/segmentation/bcss.yaml index 2c343f13..9633764f 100644 --- a/configs/vision/pathology/online/segmentation/bcss.yaml +++ b/configs/vision/pathology/online/segmentation/bcss.yaml @@ -6,6 +6,10 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/bcss} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 513} log_every_n_steps: 6 + plugins: + class_path: eva.core.plugins.SubmoduleTorchCheckpointIO + init_args: + submodule: decoder callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/online/segmentation/consep.yaml b/configs/vision/pathology/online/segmentation/consep.yaml index 06f181df..0b319c0a 100644 --- a/configs/vision/pathology/online/segmentation/consep.yaml +++ b/configs/vision/pathology/online/segmentation/consep.yaml @@ -6,6 +6,10 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/consep} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 513} log_every_n_steps: 6 + plugins: + class_path: eva.core.plugins.SubmoduleTorchCheckpointIO + init_args: + submodule: decoder callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/online/segmentation/monusac.yaml b/configs/vision/pathology/online/segmentation/monusac.yaml index b7f7ec21..75c405d2 100644 --- a/configs/vision/pathology/online/segmentation/monusac.yaml +++ b/configs/vision/pathology/online/segmentation/monusac.yaml @@ -6,6 +6,10 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/monusac} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 550} log_every_n_steps: 4 + plugins: + class_path: eva.core.plugins.SubmoduleTorchCheckpointIO + init_args: + submodule: decoder callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml b/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml index 2671ec40..0cbb4647 100644 --- a/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml +++ b/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml @@ -5,6 +5,10 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 1} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/total_segmentator_2d} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000} + plugins: + class_path: eva.core.plugins.SubmoduleTorchCheckpointIO + init_args: + submodule: decoder callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/radiology/offline/segmentation/lits.yaml b/configs/vision/radiology/offline/segmentation/lits.yaml index d9e0c490..fccc40a0 100644 --- a/configs/vision/radiology/offline/segmentation/lits.yaml +++ b/configs/vision/radiology/offline/segmentation/lits.yaml @@ -5,6 +5,10 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 1} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} + plugins: + class_path: eva.core.plugins.SubmoduleTorchCheckpointIO + init_args: + submodule: decoder callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml index a0059e34..59235bd1 100644 --- a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml +++ b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml @@ -5,6 +5,10 @@ trainer: n_runs: &N_RUNS ${oc.env:N_RUNS, 1} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} + plugins: + class_path: eva.core.plugins.SubmoduleTorchCheckpointIO + init_args: + submodule: decoder callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/radiology/online/segmentation/lits.yaml b/configs/vision/radiology/online/segmentation/lits.yaml index 3d8d2fc5..799c60b7 100644 --- a/configs/vision/radiology/online/segmentation/lits.yaml +++ b/configs/vision/radiology/online/segmentation/lits.yaml @@ -6,6 +6,10 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} log_every_n_steps: 6 + plugins: + class_path: eva.core.plugins.SubmoduleTorchCheckpointIO + init_args: + submodule: decoder callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/configs/vision/radiology/online/segmentation/lits_balanced.yaml b/configs/vision/radiology/online/segmentation/lits_balanced.yaml index cff4c88e..6a0ebd33 100644 --- a/configs/vision/radiology/online/segmentation/lits_balanced.yaml +++ b/configs/vision/radiology/online/segmentation/lits_balanced.yaml @@ -6,6 +6,10 @@ trainer: default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced} max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000} log_every_n_steps: 6 + plugins: + class_path: eva.core.plugins.SubmoduleTorchCheckpointIO + init_args: + submodule: decoder callbacks: - class_path: eva.callbacks.ConfigurationLogger - class_path: lightning.pytorch.callbacks.TQDMProgressBar diff --git a/src/eva/core/plugins/__init__.py b/src/eva/core/plugins/__init__.py new file mode 100644 index 00000000..a2088a36 --- /dev/null +++ b/src/eva/core/plugins/__init__.py @@ -0,0 +1,5 @@ +"""Plug-ins API.""" + +from eva.core.plugins.io import SubmoduleTorchCheckpointIO + +__all__ = ["SubmoduleTorchCheckpointIO"] diff --git a/src/eva/core/plugins/io/__init__.py b/src/eva/core/plugins/io/__init__.py new file mode 100644 index 00000000..c38623b0 --- /dev/null +++ b/src/eva/core/plugins/io/__init__.py @@ -0,0 +1,5 @@ +"""IO plug-ins API.""" + +from eva.core.plugins.io.torch_io import SubmoduleTorchCheckpointIO + +__all__ = ["SubmoduleTorchCheckpointIO"] diff --git a/src/eva/core/plugins/io/torch_io.py b/src/eva/core/plugins/io/torch_io.py new file mode 100644 index 00000000..22863ee7 --- /dev/null +++ b/src/eva/core/plugins/io/torch_io.py @@ -0,0 +1,67 @@ +"""Extended version of the IO model checkpoint plugin `TorchCheckpointIO`.""" + +import os +from typing import Any, Dict + +from lightning.fabric.utilities import cloud_io +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch import plugins +from typing_extensions import override + + +class SubmoduleTorchCheckpointIO(plugins.TorchCheckpointIO): + """IO plugin which allows to additionally save only a sub-part of the full model.""" + + def __init__(self, submodule: str) -> None: + """Initializes the plugin. + + Args: + submodule: The name of the submodule to additionally save. + """ + super().__init__() + + self._submodule = submodule + + @override + def save_checkpoint( + self, + checkpoint: Dict[str, Any], + path: _PATH, + storage_options: Any | None = None, + ) -> None: + super().save_checkpoint(checkpoint, path, storage_options) + self._save_submodule(checkpoint["state_dict"], path) + + @override + def remove_checkpoint(self, path: _PATH) -> None: + super().remove_checkpoint(path) + self._remove_submodule(path) + + def _save_submodule(self, module_checkpoint: Dict[str, Any], module_path: _PATH) -> None: + """Saves the submodule.""" + path = self._submodule_path(module_path) + state_dict = self._submodule_state_dict(module_checkpoint) + + os.makedirs(os.path.dirname(path), exist_ok=True) + cloud_io._atomic_save(state_dict, path) + + def _remove_submodule(self, module_path: _PATH) -> None: + """Removes the submodule.""" + path = self._submodule_path(module_path) + fs = cloud_io.get_filesystem(path) + if fs.exists(path): + fs.rm(path, recursive=True) + + def _submodule_path(self, module_path: _PATH) -> str: + """Constructs and returns the submodule checkpoint path.""" + root, basename = os.path.split(module_path) + return os.path.join(root, self._submodule, basename.replace(".ckpt", ".pth")) + + def _submodule_state_dict(self, module_state_dict: Dict[str, Any]) -> Dict[str, Any]: + """Returns the submodule `state_dict`.""" + key = self._submodule if self._submodule.endswith(".") else self._submodule + "." + return { + module.replace(key, ""): weights + for module, weights in module_state_dict.items() + if module.startswith(key) + }