Skip to content

Commit

Permalink
functionality to use pretrained pie model
Browse files Browse the repository at this point in the history
  • Loading branch information
Bhuvanesh-Verma committed Jul 18, 2024
1 parent fc55427 commit 91eafe4
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
3 changes: 3 additions & 0 deletions configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ monitor_mode: "max"
# seed for random number generators in pytorch, numpy and python.random
seed: null

# path to pretrained pytorch-ie model that updates the weights of base model with pretrained pie model
pretrained_pie_model_path: null

# simply provide checkpoint path to resume training
ckpt_path: null

Expand Down
24 changes: 23 additions & 1 deletion src/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pyrootutils
from pytorch_ie import AutoModel

root = pyrootutils.setup_root(
search_from=__file__,
Expand Down Expand Up @@ -40,7 +41,9 @@
import pytorch_lightning as pl
from omegaconf import DictConfig
from pie_datasets import DatasetDict
from pie_modules.models import * # noqa: F403
from pie_modules.models.interface import RequiresTaskmoduleConfig
from pie_modules.taskmodules import * # noqa: F403
from pytorch_ie.core import PyTorchIEModel, TaskModule
from pytorch_ie.models import * # noqa: F403
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
Expand Down Expand Up @@ -140,9 +143,28 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:

# initialize the model
model: PyTorchIEModel = hydra.utils.instantiate(
cfg.model, _convert_="partial", **additional_model_kwargs
cfg.model,
_convert_="partial",
is_from_pretrained=cfg.get("pretrained_pie_model_path", None) is not None,
**additional_model_kwargs,
)

if "pretrained_pie_model_path" in cfg:
pie_model = AutoModel.from_pretrained(cfg["pretrained_pie_model_path"])
loaded_state_dict = pie_model.state_dict()
if "pretrained_pie_model_prefix_mapping" in cfg:
state_dict_to_load = {}
for prefix_from, prefix_to in cfg["pretrained_pie_model_prefix_mapping"].items():
for name, value in loaded_state_dict.items():
if name.startswith(prefix_from):
new_name = prefix_to + name[len(prefix_from) :]
state_dict_to_load[new_name] = value
else:
state_dict_to_load = loaded_state_dict
model.load_state_dict(
state_dict_to_load, strict=("pretrained_pie_model_prefix_mapping" not in cfg)
)

log.info("Instantiating callbacks...")
callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks")

Expand Down

0 comments on commit 91eafe4

Please sign in to comment.