Skip to content

Commit

Permalink
GH-3536: fix state dict key mismatch for embeddings in TextPairRegres…
Browse files Browse the repository at this point in the history
…sor. this was causing a bug where the model failed to load from the output state dict
  • Loading branch information
MattGPT-ai authored and helpmefindaname committed Aug 23, 2024
1 parent 3d8f078 commit a458f5b
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions flair/models/pairwise_regression_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import typing
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
Expand Down Expand Up @@ -91,7 +90,7 @@ def label_type(self):

def get_used_tokens(
self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True
) -> typing.Iterable[List[str]]:
) -> Iterable[List[str]]:
for sentence_pair in _iter_dataset(corpus.get_all_sentences()):
yield [t.text for t in sentence_pair.first]
yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)]
Expand Down Expand Up @@ -204,10 +203,16 @@ def _get_state_dict(self):
return model_state

@classmethod
def _init_model_with_state_dict(cls, state, **kwargs):
# add DefaultClassifier arguments
def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs):
"""Initializes a TextPairRegressor model from a state dictionary (exported by _get_state_dict).
Requires keys 'state_dict', 'document_embeddings', and 'label_type' in the state dictionary.
"""
if "document_embeddings" in state:
state["embeddings"] = state.pop("document_embeddings") # need to rename this parameter
# add Model arguments
for arg in [
"document_embeddings",
"embeddings",
"label_type",
"embed_separately",
"dropout",
Expand Down

0 comments on commit a458f5b

Please sign in to comment.