v0.0.32: add pretrained state loading
vahidzee committed Aug 16, 2023
1 parent cc2a66e · commit 88604b0
Showing 2 changed files with 8 additions and 1 deletion.
lightning_toolbox/training/module.py (7 additions, 0 deletions)
@@ -59,6 +59,7 @@ def __init__(
         # initialization settings
         save_hparams: bool = True,
         initialize_superclass: bool = True,
+        pretrained_state: th.Optional[str] = None,
     ) -> None:
         if initialize_superclass:
             super().__init__()
@@ -127,6 +128,12 @@ def __init__(
         if model is not None or model_args is not None or model_cls is not None:
             self.model = model if model is not None else dy.eval(model_cls)(**(model_args or dict()))

+        # load pretrained checkpoint
+        # TODO: add smart loading of the state dict, e.g. if the checkpoint has a
+        # different structure, load only the common parts and ignore the rest
+        if pretrained_state is not None:
+            self.load_state_dict(torch.load(pretrained_state)["state_dict"])
+
     @functools.cached_property
     def __optimizers_is_active_list(self):
         if self.optimizers_list is None:
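
Note that the TODO above is not implemented in this commit, so loading is strict: a checkpoint whose keys or shapes differ from the current model will raise an error. Below is a minimal sketch of the "smart" partial loading the TODO describes, written against plain PyTorch; the helper name load_compatible_state and its arguments are hypothetical and not part of lightning_toolbox.

    import torch

    def load_compatible_state(model: torch.nn.Module, ckpt_path: str) -> None:
        # Lightning checkpoints store the weights under the "state_dict" key
        pretrained = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        current = model.state_dict()
        # keep only the entries that also exist in the current model with matching shapes
        compatible = {
            key: value
            for key, value in pretrained.items()
            if key in current and current[key].shape == value.shape
        }
        current.update(compatible)
        # strict loading is safe here because every key comes from the model itself
        model.load_state_dict(current)
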
setup.py (1 addition, 1 deletion)
@@ -20,7 +20,7 @@
 setup(
     name="lightning_toolbox",
     packages=find_packages(include=["lightning_toolbox", "lightning_toolbox.*"]),
-    version="0.0.32",
+    version="0.0.33",
     license="MIT",
     description="A collection of utilities for PyTorch Lightning.",
     long_description=long_description,
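
For context, a hedged usage sketch of the new argument, assuming the class defined in lightning_toolbox/training/module.py is exported as lightning_toolbox.TrainingModule and that a checkpoint with a matching model exists at the given path:

    import torch
    from lightning_toolbox import TrainingModule  # assumed export path

    # wrap an existing torch model and restore its weights from a previous
    # run's checkpoint at construction time (strict loading, per this commit)
    module = TrainingModule(
        model=torch.nn.Linear(10, 2),
        pretrained_state="checkpoints/last.ckpt",  # hypothetical path
    )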
