From 91eafe4d4ea5194472c472f08031d0f434bc1eca Mon Sep 17 00:00:00 2001
From: Bhuvanesh Verma
Date: Thu, 18 Jul 2024 18:12:13 +0200
Subject: [PATCH 1/3] functionality to use pretrained pie model

---
 configs/train.yaml |  3 +++
 src/train.py       | 24 +++++++++++++++++++++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/configs/train.yaml b/configs/train.yaml
index 699f9df..abf5af6 100644
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -61,6 +61,9 @@ monitor_mode: "max"
 # seed for random number generators in pytorch, numpy and python.random
 seed: null
 
+# path to pretrained pytorch-ie model that updates the weights of base model with pretrained pie model
+pretrained_pie_model_path: null
+
 # simply provide checkpoint path to resume training
 ckpt_path: null
 
diff --git a/src/train.py b/src/train.py
index 88ef652..f0cb36c 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,4 +1,5 @@
 import pyrootutils
+from pytorch_ie import AutoModel
 
 root = pyrootutils.setup_root(
     search_from=__file__,
@@ -40,7 +41,9 @@
 import pytorch_lightning as pl
 from omegaconf import DictConfig
 from pie_datasets import DatasetDict
+from pie_modules.models import *  # noqa: F403
 from pie_modules.models.interface import RequiresTaskmoduleConfig
+from pie_modules.taskmodules import *  # noqa: F403
 from pytorch_ie.core import PyTorchIEModel, TaskModule
 from pytorch_ie.models import *  # noqa: F403
 from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
@@ -140,9 +143,28 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
 
     # initialize the model
     model: PyTorchIEModel = hydra.utils.instantiate(
-        cfg.model, _convert_="partial", **additional_model_kwargs
+        cfg.model,
+        _convert_="partial",
+        is_from_pretrained=cfg.get("pretrained_pie_model_path", None) is not None,
+        **additional_model_kwargs,
     )
 
+    if "pretrained_pie_model_path" in cfg:
+        pie_model = AutoModel.from_pretrained(cfg["pretrained_pie_model_path"])
+        loaded_state_dict = pie_model.state_dict()
+        if "pretrained_pie_model_prefix_mapping" in cfg:
+            state_dict_to_load = {}
+            for prefix_from, prefix_to in cfg["pretrained_pie_model_prefix_mapping"].items():
+                for name, value in loaded_state_dict.items():
+                    if name.startswith(prefix_from):
+                        new_name = prefix_to + name[len(prefix_from) :]
+                        state_dict_to_load[new_name] = value
+        else:
+            state_dict_to_load = loaded_state_dict
+        model.load_state_dict(
+            state_dict_to_load, strict=("pretrained_pie_model_prefix_mapping" not in cfg)
+        )
+
     log.info("Instantiating callbacks...")
     callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks")
 

From ef033ca1fcac6d4d419ed22126e013f8573beff7 Mon Sep 17 00:00:00 2001
From: Bhuvanesh Verma
Date: Fri, 19 Jul 2024 12:14:07 +0200
Subject: [PATCH 2/3] fix NoneType error

---
 src/train.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/train.py b/src/train.py
index f0cb36c..d1f410b 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,5 +1,4 @@
 import pyrootutils
-from pytorch_ie import AutoModel
 
 root = pyrootutils.setup_root(
     search_from=__file__,
@@ -44,6 +43,7 @@
 from pie_modules.models import *  # noqa: F403
 from pie_modules.models.interface import RequiresTaskmoduleConfig
 from pie_modules.taskmodules import *  # noqa: F403
+from pytorch_ie import AutoModel
 from pytorch_ie.core import PyTorchIEModel, TaskModule
 from pytorch_ie.models import *  # noqa: F403
 from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
@@ -149,10 +149,11 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
         **additional_model_kwargs,
     )
 
-    if "pretrained_pie_model_path" in cfg:
+    if cfg.get("pretrained_pie_model_path", None) is not None:
         pie_model = AutoModel.from_pretrained(cfg["pretrained_pie_model_path"])
         loaded_state_dict = pie_model.state_dict()
-        if "pretrained_pie_model_prefix_mapping" in cfg:
+        has_prefix_mapping = cfg.get("pretrained_pie_model_prefix_mapping", None) is not None
+        if has_prefix_mapping:
             state_dict_to_load = {}
             for prefix_from, prefix_to in cfg["pretrained_pie_model_prefix_mapping"].items():
                 for name, value in loaded_state_dict.items():
@@ -161,9 +162,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
                         state_dict_to_load[new_name] = value
         else:
             state_dict_to_load = loaded_state_dict
-        model.load_state_dict(
-            state_dict_to_load, strict=("pretrained_pie_model_prefix_mapping" not in cfg)
-        )
+        model.load_state_dict(state_dict_to_load, strict=not has_prefix_mapping)
 
     log.info("Instantiating callbacks...")
     callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks")

From 81d0b3c9f72b71850b3756f51b34ac338e1a9cad Mon Sep 17 00:00:00 2001
From: ArneBinder
Date: Fri, 19 Jul 2024 14:19:46 +0200
Subject: [PATCH 3/3] add documentation to src/train.py

---
 src/train.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/train.py b/src/train.py
index d1f410b..ebe26e9 100644
--- a/src/train.py
+++ b/src/train.py
@@ -145,6 +145,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     model: PyTorchIEModel = hydra.utils.instantiate(
         cfg.model,
         _convert_="partial",
+        # In the case of loading weights from a pretrained PIE model, we do not need to download the base (transformer) model in the model constructors. We disable that by passing is_from_pretrained=True in these cases.
         is_from_pretrained=cfg.get("pretrained_pie_model_path", None) is not None,
         **additional_model_kwargs,
     )
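
For orientation, below is a minimal usage sketch of the new options in a training config. It is illustrative only: the path and the prefix values are made up, and note that the patches declare only pretrained_pie_model_path in configs/train.yaml, while pretrained_pie_model_prefix_mapping is read from the config by src/train.py without a default entry there.

    # Hypothetical usage sketch -- the path and prefixes below are illustrative, not part of the patches.
    # Initialize the new model from a previously trained PIE model before training starts.
    pretrained_pie_model_path: /path/to/previously/trained/pie_model

    # Optional prefix remapping: keys of the loaded state dict that start with a prefix on the
    # left are renamed to use the prefix on the right; all other keys are dropped. When a mapping
    # is given, load_state_dict is called with strict=False, so e.g. an identity mapping like this
    # reuses only the "model." (backbone) weights while all remaining parameters of the new model
    # keep their fresh initialization.
    pretrained_pie_model_prefix_mapping:
      "model.": "model."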