From 91eafe4d4ea5194472c472f08031d0f434bc1eca Mon Sep 17 00:00:00 2001 From: Bhuvanesh Verma Date: Thu, 18 Jul 2024 18:12:13 +0200 Subject: [PATCH] functionality to use pretrained pie model --- configs/train.yaml | 3 +++ src/train.py | 24 +++++++++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/configs/train.yaml b/configs/train.yaml index 699f9df..abf5af6 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -61,6 +61,9 @@ monitor_mode: "max" # seed for random number generators in pytorch, numpy and python.random seed: null +# path to a pretrained pytorch-ie model whose weights are loaded into the newly instantiated model +pretrained_pie_model_path: null + # simply provide checkpoint path to resume training ckpt_path: null diff --git a/src/train.py b/src/train.py index 88ef652..f0cb36c 100644 --- a/src/train.py +++ b/src/train.py @@ -1,4 +1,5 @@ import pyrootutils +from pytorch_ie import AutoModel root = pyrootutils.setup_root( search_from=__file__, @@ -40,7 +41,9 @@ import pytorch_lightning as pl from omegaconf import DictConfig from pie_datasets import DatasetDict +from pie_modules.models import * # noqa: F403 from pie_modules.models.interface import RequiresTaskmoduleConfig +from pie_modules.taskmodules import * # noqa: F403 from pytorch_ie.core import PyTorchIEModel, TaskModule from pytorch_ie.models import * # noqa: F403 from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses @@ -140,9 +143,28 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]: # initialize the model model: PyTorchIEModel = hydra.utils.instantiate( - cfg.model, _convert_="partial", **additional_model_kwargs + cfg.model, + _convert_="partial", + is_from_pretrained=cfg.get("pretrained_pie_model_path", None) is not None, + **additional_model_kwargs, ) + if "pretrained_pie_model_path" in cfg: + pie_model = AutoModel.from_pretrained(cfg["pretrained_pie_model_path"]) + loaded_state_dict = pie_model.state_dict() + if 
"pretrained_pie_model_prefix_mapping" in cfg: + state_dict_to_load = {} + for prefix_from, prefix_to in cfg["pretrained_pie_model_prefix_mapping"].items(): + for name, value in loaded_state_dict.items(): + if name.startswith(prefix_from): + new_name = prefix_to + name[len(prefix_from) :] + state_dict_to_load[new_name] = value + else: + state_dict_to_load = loaded_state_dict + model.load_state_dict( + state_dict_to_load, strict=("pretrained_pie_model_prefix_mapping" not in cfg) + ) + log.info("Instantiating callbacks...") callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks")