Add save submodule plugin #707

Open · wants to merge 1 commit into main
4 changes: 4 additions & 0 deletions configs/vision/pathology/offline/segmentation/bcss.yaml
@@ -6,6 +6,10 @@ trainer:
   default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/bcss}
   max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000}
   log_every_n_steps: 6
+  plugins:
+    class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
+    init_args:
+      submodule: decoder
   callbacks:
     - class_path: eva.callbacks.ConfigurationLogger
     - class_path: lightning.pytorch.callbacks.TQDMProgressBar

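With this plugins block in place, every checkpoint Lightning writes for the run would be accompanied by a decoder-only weights file in a decoder/ subdirectory next to it. Roughly, assuming Lightning's default checkpoint naming (paths illustrative, not from the PR):

    <checkpoints_dir>/epoch=0-step=2000.ckpt            full model, saved as before
    <checkpoints_dir>/decoder/epoch=0-step=2000.pth     decoder state dict only, written by the plugin

The same four-line plugins block is added to each of the segmentation configs that follow.
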
4 changes: 4 additions & 0 deletions configs/vision/pathology/offline/segmentation/consep.yaml
@@ -6,6 +6,10 @@ trainer:
   default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/consep}
   max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000}
   log_every_n_steps: 6
+  plugins:
+    class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
+    init_args:
+      submodule: decoder
   callbacks:
     - class_path: eva.callbacks.ConfigurationLogger
     - class_path: lightning.pytorch.callbacks.TQDMProgressBar

4 changes: 4 additions & 0 deletions configs/vision/pathology/offline/segmentation/monusac.yaml
@@ -6,6 +6,10 @@ trainer:
   default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/monusac}
   max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000}
   log_every_n_steps: 6
+  plugins:
+    class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
+    init_args:
+      submodule: decoder
   callbacks:
     - class_path: eva.callbacks.ConfigurationLogger
     - class_path: lightning.pytorch.callbacks.TQDMProgressBar


@@ -5,6 +5,10 @@ trainer:
   n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
   default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/total_segmentator_2d}
   max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000}
+  plugins:
+    class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
+    init_args:
+      submodule: decoder
   callbacks:
     - class_path: eva.callbacks.ConfigurationLogger
     - class_path: lightning.pytorch.callbacks.TQDMProgressBar

4 changes: 4 additions & 0 deletions configs/vision/pathology/online/segmentation/bcss.yaml
@@ -6,6 +6,10 @@ trainer:
   default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/bcss}
   max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 513}
   log_every_n_steps: 6
+  plugins:
+    class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
+    init_args:
+      submodule: decoder
   callbacks:
     - class_path: eva.callbacks.ConfigurationLogger
     - class_path: lightning.pytorch.callbacks.TQDMProgressBar

4 changes: 4 additions & 0 deletions configs/vision/pathology/online/segmentation/consep.yaml
@@ -6,6 +6,10 @@ trainer:
   default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/consep}
   max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 513}
   log_every_n_steps: 6
+  plugins:
+    class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
+    init_args:
+      submodule: decoder
   callbacks:
     - class_path: eva.callbacks.ConfigurationLogger
     - class_path: lightning.pytorch.callbacks.TQDMProgressBar

4 changes: 4 additions & 0 deletions configs/vision/pathology/online/segmentation/monusac.yaml
@@ -6,6 +6,10 @@ trainer:
   default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224}/monusac}
   max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 550}
   log_every_n_steps: 4
+  plugins:
+    class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
+    init_args:
+      submodule: decoder
   callbacks:
     - class_path: eva.callbacks.ConfigurationLogger
     - class_path: lightning.pytorch.callbacks.TQDMProgressBar


@@ -5,6 +5,10 @@ trainer:
   n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
   default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/total_segmentator_2d}
   max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 20000}
+  plugins:
+    class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
+    init_args:
+      submodule: decoder
   callbacks:
     - class_path: eva.callbacks.ConfigurationLogger
     - class_path: lightning.pytorch.callbacks.TQDMProgressBar

4 changes: 4 additions & 0 deletions configs/vision/radiology/offline/segmentation/lits.yaml
@@ -5,6 +5,10 @@ trainer:
   n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
   default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits}
   max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
+  plugins:
+    class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
+    init_args:
+      submodule: decoder
   callbacks:
     - class_path: eva.callbacks.ConfigurationLogger
     - class_path: lightning.pytorch.callbacks.TQDMProgressBar


@@ -5,6 +5,10 @@ trainer:
   n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
   default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced}
   max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
+  plugins:
+    class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
+    init_args:
+      submodule: decoder
   callbacks:
     - class_path: eva.callbacks.ConfigurationLogger
     - class_path: lightning.pytorch.callbacks.TQDMProgressBar

4 changes: 4 additions & 0 deletions configs/vision/radiology/online/segmentation/lits.yaml
@@ -6,6 +6,10 @@ trainer:
   default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits}
   max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
   log_every_n_steps: 6
+  plugins:
+    class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
+    init_args:
+      submodule: decoder
   callbacks:
     - class_path: eva.callbacks.ConfigurationLogger
     - class_path: lightning.pytorch.callbacks.TQDMProgressBar


@@ -6,6 +6,10 @@ trainer:
   default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/lits_balanced}
   max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 500000}
   log_every_n_steps: 6
+  plugins:
+    class_path: eva.core.plugins.SubmoduleTorchCheckpointIO
+    init_args:
+      submodule: decoder
   callbacks:
     - class_path: eva.callbacks.ConfigurationLogger
     - class_path: lightning.pytorch.callbacks.TQDMProgressBar

5 changes: 5 additions & 0 deletions src/eva/core/plugins/__init__.py
@@ -0,0 +1,5 @@
+"""Plug-ins API."""
+
+from eva.core.plugins.io import SubmoduleTorchCheckpointIO
+
+__all__ = ["SubmoduleTorchCheckpointIO"]

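Since the class is re-exported here, it could also be registered with a Trainer programmatically rather than through YAML. A minimal sketch, assuming a standard Lightning setup (the max_steps value is illustrative):

import lightning.pytorch as pl

from eva.core.plugins import SubmoduleTorchCheckpointIO

# Each saved .ckpt will additionally produce a decoder-only .pth file
# in a sibling decoder/ directory.
trainer = pl.Trainer(
    max_steps=2000,
    plugins=[SubmoduleTorchCheckpointIO(submodule="decoder")],
)
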
5 changes: 5 additions & 0 deletions src/eva/core/plugins/io/__init__.py
@@ -0,0 +1,5 @@
+"""IO plug-ins API."""
+
+from eva.core.plugins.io.torch_io import SubmoduleTorchCheckpointIO
+
+__all__ = ["SubmoduleTorchCheckpointIO"]

67 changes: 67 additions & 0 deletions src/eva/core/plugins/io/torch_io.py
@@ -0,0 +1,67 @@
"""Extended version of the IO model checkpoint plugin `TorchCheckpointIO`."""

import os
from typing import Any, Dict

from lightning.fabric.utilities import cloud_io
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch import plugins
from typing_extensions import override


class SubmoduleTorchCheckpointIO(plugins.TorchCheckpointIO):
"""IO plugin which allows to additionally save only a sub-part of the full model."""

def __init__(self, submodule: str) -> None:
"""Initializes the plugin.

Args:
submodule: The name of the submodule to additionally save.
"""
super().__init__()

self._submodule = submodule

@override
def save_checkpoint(
self,
checkpoint: Dict[str, Any],
path: _PATH,
storage_options: Any | None = None,
) -> None:
super().save_checkpoint(checkpoint, path, storage_options)

Collaborator review comment:
    This will also save the checkpoint of the full module to path, if I understand
    correctly? For big FMs / encoders those checkpoints can be quite heavy; it'd be
    nice to have the option to only save the decoder checkpoint.
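
A minimal sketch of what such an opt-out could look like, as a follow-up idea rather than anything in this PR; the OptionalFullCheckpointIO name and the save_full flag are hypothetical:

from typing import Any, Dict

from lightning.fabric.utilities.types import _PATH
from typing_extensions import override

from eva.core.plugins import SubmoduleTorchCheckpointIO


class OptionalFullCheckpointIO(SubmoduleTorchCheckpointIO):
    """Hypothetical variant that can skip writing the full checkpoint."""

    def __init__(self, submodule: str, save_full: bool = True) -> None:
        super().__init__(submodule)
        self._save_full = save_full  # assumed flag, not part of this PR

    @override
    def save_checkpoint(
        self,
        checkpoint: Dict[str, Any],
        path: _PATH,
        storage_options: Any | None = None,
    ) -> None:
        if self._save_full:
            # Parent behavior: writes the full checkpoint and the submodule file.
            super().save_checkpoint(checkpoint, path, storage_options)
        else:
            # Skip the potentially heavy full-model file and persist only the
            # submodule state dict next to where the checkpoint would have gone.
            self._save_submodule(checkpoint["state_dict"], path)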

+        self._save_submodule(checkpoint["state_dict"], path)
+
+    @override
+    def remove_checkpoint(self, path: _PATH) -> None:
+        super().remove_checkpoint(path)
+        self._remove_submodule(path)
+
+    def _save_submodule(self, module_checkpoint: Dict[str, Any], module_path: _PATH) -> None:
+        """Saves the submodule checkpoint."""
+        path = self._submodule_path(module_path)
+        state_dict = self._submodule_state_dict(module_checkpoint)
+
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        cloud_io._atomic_save(state_dict, path)
+
+    def _remove_submodule(self, module_path: _PATH) -> None:
+        """Removes the submodule checkpoint file."""
+        path = self._submodule_path(module_path)
+        fs = cloud_io.get_filesystem(path)
+        if fs.exists(path):
+            fs.rm(path, recursive=True)
+
+    def _submodule_path(self, module_path: _PATH) -> str:
+        """Constructs and returns the submodule checkpoint path."""
+        root, basename = os.path.split(module_path)
+        return os.path.join(root, self._submodule, basename.replace(".ckpt", ".pth"))
+
+    def _submodule_state_dict(self, module_state_dict: Dict[str, Any]) -> Dict[str, Any]:
+        """Returns the submodule `state_dict`."""
+        key = self._submodule if self._submodule.endswith(".") else self._submodule + "."
+        return {
+            module.replace(key, ""): weights
+            for module, weights in module_state_dict.items()
+            if module.startswith(key)
+        }
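
For completeness, a hedged sketch of consuming the saved file: _submodule_state_dict strips the "decoder." prefix from the keys, so the .pth can be loaded straight into a bare decoder module. The path and the toy module below are illustrative assumptions, not taken from the PR:

import torch
from torch import nn

# Toy stand-in for the real decoder used by the configs above.
decoder = nn.Sequential(nn.Conv2d(384, 64, kernel_size=1), nn.Conv2d(64, 6, kernel_size=1))

# The plugin writes e.g. .../checkpoints/decoder/epoch=0-step=2000.pth
# alongside .../checkpoints/epoch=0-step=2000.ckpt.
state_dict = torch.load("logs/bcss/checkpoints/decoder/epoch=0-step=2000.pth", map_location="cpu")
decoder.load_state_dict(state_dict)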