From 88604b0d44acd0d3807759b73aa81a2bafc4e04a Mon Sep 17 00:00:00 2001 From: Vahid Zehtab <33608325+vahidzee@users.noreply.github.com> Date: Wed, 16 Aug 2023 12:36:22 -0400 Subject: [PATCH] v0.0.33: add pretrained state loading --- lightning_toolbox/training/module.py | 7 +++++++ setup.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/lightning_toolbox/training/module.py b/lightning_toolbox/training/module.py index d362960..9cbfc32 100644 --- a/lightning_toolbox/training/module.py +++ b/lightning_toolbox/training/module.py @@ -59,6 +59,7 @@ def __init__( # initialization settings save_hparams: bool = True, initialize_superclass: bool = True, + pretrained_state: th.Optional[str] = None, ) -> None: if initialize_superclass: super().__init__() @@ -127,6 +128,12 @@ def __init__( if model is not None or model_args is not None or model_cls is not None: self.model = model if model is not None else dy.eval(model_cls)(**(model_args or dict())) + # load pretrained checkpoint + # TODO: add smart loading of state dict, e.g. if the checkpoint has different structure, + # load only the common parts and ignore the rest + if pretrained_state is not None: + self.load_state_dict(torch.load(pretrained_state)["state_dict"]) + @functools.cached_property def __optimizers_is_active_list(self): if self.optimizers_list is None: diff --git a/setup.py b/setup.py index 84772c3..3e552fc 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ setup( name="lightning_toolbox", packages=find_packages(include=["lightning_toolbox", "lightning_toolbox.*"]), - version="0.0.32", + version="0.0.33", license="MIT", description="A collection of utilities for PyTorch Lightning.", long_description=long_description,