From a458f5b429cb52294e302f91a53c5e9f69ddb72e Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Thu, 22 Aug 2024 15:49:36 -0700 Subject: [PATCH] GH-3536: fix state dict key mismatch for embeddings in TextPairRegressor. this was causing a bug where the model failed to load from the output state dict --- flair/models/pairwise_regression_model.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/flair/models/pairwise_regression_model.py b/flair/models/pairwise_regression_model.py index 39a3192aa..c67d81ca1 100644 --- a/flair/models/pairwise_regression_model.py +++ b/flair/models/pairwise_regression_model.py @@ -1,6 +1,5 @@ -import typing from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -91,7 +90,7 @@ def label_type(self): def get_used_tokens( self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True - ) -> typing.Iterable[List[str]]: + ) -> Iterable[List[str]]: for sentence_pair in _iter_dataset(corpus.get_all_sentences()): yield [t.text for t in sentence_pair.first] yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)] @@ -204,10 +203,16 @@ def _get_state_dict(self): return model_state @classmethod - def _init_model_with_state_dict(cls, state, **kwargs): - # add DefaultClassifier arguments + def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs): + """Initializes a TextPairRegressor model from a state dictionary (exported by _get_state_dict). + + Requires keys 'state_dict', 'document_embeddings', and 'label_type' in the state dictionary. + """ + if "document_embeddings" in state: + state["embeddings"] = state.pop("document_embeddings") # need to rename this parameter + # add Model arguments for arg in [ - "document_embeddings", + "embeddings", "label_type", "embed_separately", "dropout",