feat: Add chunking function for sequence tagger training on sentences exceeding token limit #3520

Open · wants to merge 2 commits into base: master
6 changes: 5 additions & 1 deletion flair/class_utils.py
@@ -1,11 +1,15 @@
import importlib
import inspect
from types import ModuleType
from typing import Any, Iterable, List, Optional, Type, TypeVar, Union, overload
from typing import Any, Iterable, List, Optional, Protocol, Type, TypeVar, Union, overload

T = TypeVar("T")


class StringLike(Protocol):
def __str__(self) -> str: ...


def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]:
for subclass in cls.__subclasses__():
yield from get_non_abstract_subclasses(subclass)
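
A quick illustration of what the new `StringLike` protocol enables: any object implementing `__str__` can be passed where a string-like prefix is expected, as in the `tsv_header` change in the training_utils.py diff below. The sketch is self-contained (the protocol is re-declared inline); the `RunTag` class is hypothetical and `tsv_header` here is an abridged stand-in, not the real method.

from typing import Optional, Protocol


class StringLike(Protocol):
    def __str__(self) -> str: ...


class RunTag:
    """Hypothetical object that is string-convertible but not a str."""

    def __str__(self) -> str:
        return "dev"


def tsv_header(prefix: Optional[StringLike] = None) -> str:
    # Abridged stand-in for MetricRegression.tsv_header: any string-convertible
    # prefix (or None) is accepted; f-string interpolation calls __str__.
    if prefix:
        return f"{prefix}_MEAN_SQUARED_ERROR\t{prefix}_PEARSON"
    return "MEAN_SQUARED_ERROR\tPEARSON"


print(tsv_header("train"))   # plain str
print(tsv_header(RunTag()))  # any object with __str__
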
224 changes: 181 additions & 43 deletions flair/training_utils.py
@@ -1,37 +1,42 @@
import logging
import pathlib
import random
from collections import defaultdict
from enum import Enum
from functools import reduce
from math import inf
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Literal, NamedTuple, Optional, Union

from numpy import ndarray
from scipy.stats import pearsonr, spearmanr
from scipy.stats._stats_py import PearsonRResult, SignificanceResult
from sklearn.metrics import mean_absolute_error, mean_squared_error
from torch.optim import Optimizer
from torch.utils.data import Dataset

import flair
from flair.data import DT, Dictionary, Sentence, _iter_dataset
from flair.class_utils import StringLike
from flair.data import DT, Dictionary, Sentence, Token, _iter_dataset

log = logging.getLogger("flair")
MinMax = Literal["min", "max"]
logger = logging.getLogger("flair")


class Result:
def __init__(
self,
main_score: float,
detailed_results: str,
classification_report: dict = {},
scores: dict = {},
classification_report: Optional[Dict] = None,
scores: Optional[Dict] = None,
) -> None:
assert "loss" in scores, "No loss provided."
assert scores is not None and "loss" in scores, "No loss provided."

self.main_score: float = main_score
self.scores = scores
self.detailed_results: str = detailed_results
self.classification_report = classification_report
self.classification_report = classification_report if classification_report is not None else {}

@property
def loss(self):
@@ -42,40 +47,36 @@ def __str__(self) -> str:


class MetricRegression:
def __init__(self, name) -> None:
def __init__(self, name: str) -> None:
self.name = name

self.true: List[float] = []
self.pred: List[float] = []

def mean_squared_error(self):
def mean_squared_error(self) -> Union[float, ndarray]:
return mean_squared_error(self.true, self.pred)

def mean_absolute_error(self):
return mean_absolute_error(self.true, self.pred)

def pearsonr(self):
def pearsonr(self) -> PearsonRResult:
return pearsonr(self.true, self.pred)[0]

def spearmanr(self):
def spearmanr(self) -> SignificanceResult:
return spearmanr(self.true, self.pred)[0]

# dummy return to fulfill trainer.train() needs
def micro_avg_f_score(self):
return self.mean_squared_error()

def to_tsv(self):
def to_tsv(self) -> str:
return f"{self.mean_squared_error()}\t{self.mean_absolute_error()}\t{self.pearsonr()}\t{self.spearmanr()}"

@staticmethod
def tsv_header(prefix=None):
def tsv_header(prefix: StringLike = None) -> str:
if prefix:
return f"{prefix}_MEAN_SQUARED_ERROR\t{prefix}_MEAN_ABSOLUTE_ERROR\t{prefix}_PEARSON\t{prefix}_SPEARMAN"

return "MEAN_SQUARED_ERROR\tMEAN_ABSOLUTE_ERROR\tPEARSON\tSPEARMAN"

@staticmethod
def to_empty_tsv():
def to_empty_tsv() -> str:
return "\t_\t_\t_\t_"

def __str__(self) -> str:
Expand All @@ -99,13 +100,13 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) ->
self.weights_dict: Dict[str, Dict[int, List[float]]] = defaultdict(lambda: defaultdict(list))
self.number_of_weights = number_of_weights

def extract_weights(self, state_dict, iteration):
def extract_weights(self, state_dict: Dict, iteration: int) -> None:
for key in state_dict:
vec = state_dict[key]
# print(vec)
try:
weights_to_watch = min(self.number_of_weights, reduce(lambda x, y: x * y, list(vec.size())))
except Exception:
except Exception as e:
logger.debug(e)
continue

if key not in self.weights_dict:
@@ -193,15 +194,15 @@ class AnnealOnPlateau:
def __init__(
self,
optimizer,
mode="min",
aux_mode="min",
factor=0.1,
patience=10,
initial_extra_patience=0,
verbose=False,
cooldown=0,
min_lr=0,
eps=1e-8,
mode: MinMax = "min",
aux_mode: MinMax = "min",
factor: float = 0.1,
patience: int = 10,
initial_extra_patience: int = 0,
verbose: bool = False,
cooldown: int = 0,
min_lr: float = 0.0,
eps: float = 1e-8,
) -> None:
if factor >= 1.0:
raise ValueError("Factor should be < 1.0.")
@@ -212,6 +213,7 @@ def __init__(
raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
self.optimizer = optimizer

self.min_lrs: List[float]
if isinstance(min_lr, (list, tuple)):
if len(min_lr) != len(optimizer.param_groups):
raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}")
@@ -229,7 +231,7 @@ def __init__(
self.best = None
self.best_aux = None
self.num_bad_epochs = None
self.mode_worse = None # the worse value for the chosen mode
self.mode_worse: Optional[float] = None # the worse value for the chosen mode
self.eps = eps
self.last_epoch = 0
self._init_is_better(mode=mode)
@@ -256,7 +258,7 @@ def step(self, metric, auxiliary_metric=None) -> bool:
if self.mode == "max" and current > self.best:
is_better = True

if current == self.best and auxiliary_metric:
if current == self.best and auxiliary_metric is not None:
current_aux = float(auxiliary_metric)
if self.aux_mode == "min" and current_aux < self.best_aux:
is_better = True
@@ -287,20 +289,20 @@ def step(self, metric, auxiliary_metric=None) -> bool:

return reduce_learning_rate

def _reduce_lr(self, epoch):
def _reduce_lr(self, epoch: int) -> None:
for i, param_group in enumerate(self.optimizer.param_groups):
old_lr = float(param_group["lr"])
new_lr = max(old_lr * self.factor, self.min_lrs[i])
if old_lr - new_lr > self.eps:
param_group["lr"] = new_lr
if self.verbose:
log.info(f" - reducing learning rate of group {epoch} to {new_lr}")
logger.info(f" - reducing learning rate of group {epoch} to {new_lr}")

@property
def in_cooldown(self):
return self.cooldown_counter > 0

def _init_is_better(self, mode):
def _init_is_better(self, mode: MinMax) -> None:
if mode not in {"min", "max"}:
raise ValueError("mode " + mode + " is unknown!")

@@ -311,10 +313,10 @@ def _init_is_better(self, mode):

self.mode = mode

def state_dict(self):
def state_dict(self) -> Dict:
return {key: value for key, value in self.__dict__.items() if key != "optimizer"}

def load_state_dict(self, state_dict):
def load_state_dict(self, state_dict: Dict) -> None:
self.__dict__.update(state_dict)
self._init_is_better(mode=self.mode)

Expand Down Expand Up @@ -348,11 +350,11 @@ def convert_labels_to_one_hot(label_list: List[List[str]], label_dict: Dictionar
return [[1 if label in labels else 0 for label in label_dict.get_items()] for labels in label_list]


def log_line(log):
def log_line(log: logging.Logger) -> None:
log.info("-" * 100, stacklevel=3)


def add_file_handler(log, output_file):
def add_file_handler(log: logging.Logger, output_file: pathlib.Path) -> logging.FileHandler:
init_output_file(output_file.parents[0], output_file.name)
fh = logging.FileHandler(output_file, mode="w", encoding="utf-8")
fh.setLevel(logging.INFO)
@@ -363,12 +365,21 @@ def add_file_handler(log, output_file):


def store_embeddings(
data_points: Union[List[DT], Dataset], storage_mode: str, dynamic_embeddings: Optional[List[str]] = None
):
data_points: Union[List[DT], Dataset],
storage_mode: str,
dynamic_embeddings: Optional[List[str]] = None,
) -> None:
"""Stores embeddings of data points in memory or on disk.

Args:
data_points: a DataSet or list of DataPoints for which embeddings should be stored
storage_mode: store embeddings in either CPU or GPU memory, or delete them if set to 'none'
dynamic_embeddings: these are always deleted. If not passed, they are identified automatically.
"""
if isinstance(data_points, Dataset):
data_points = list(_iter_dataset(data_points))

# if memory mode option 'none' delete everything
# if storage mode option 'none' delete everything
if storage_mode == "none":
dynamic_embeddings = None

@@ -387,7 +398,7 @@ def store_embeddings(
data_point.to("cpu", pin_memory=pin_memory)


def identify_dynamic_embeddings(data_points: List[DT]):
def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List[str]]:
dynamic_embeddings = []
all_embeddings = []
for data_point in data_points:
@@ -407,3 +418,130 @@ def identify_dynamic_embeddings(data_points: List[DT]):
if not all_embeddings:
return None
return list(set(dynamic_embeddings))


class TokenEntity(NamedTuple):
"""Entity represented by token indices."""

start_token_idx: int
end_token_idx: int
label: str
value: str = "" # text value of the entity
score: float = 1.0


class CharEntity(NamedTuple):
"""Entity represented by character indices."""

start_char_idx: int
end_char_idx: int
label: str
value: str
score: float = 1.0


def create_labeled_sentence_from_tokens(
tokens: Union[List[Token], List[str]], token_entities: List[TokenEntity], type_name: str = "ner"
) -> Sentence:
"""Creates a new Sentence object from a list of tokens or strings and applies entity labels.

Tokens are recreated with the same text, but not attached to the previous sentence.

Args:
tokens: a list of Token objects or strings - only the text is used, not any labels
token_entities: a list of TokenEntity objects representing entity annotations
type_name: the type of entity label to apply
Returns:
A labeled Sentence object
"""
tokens = [Token(t.text) if isinstance(t, Token) else Token(t) for t in tokens]  # create new tokens that do not already belong to a sentence
sentence = Sentence(tokens, use_tokenizer=True)
for entity in token_entities:
sentence[entity.start_token_idx : entity.end_token_idx].add_label(type_name, entity.label, score=entity.score)
return sentence

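For illustration, a minimal usage sketch of this helper (the tokens, indices, and labels are invented; `Sentence.get_spans` is assumed from the existing flair API):

from flair.data import Token
from flair.training_utils import TokenEntity, create_labeled_sentence_from_tokens

tokens = [Token("George"), Token("Washington"), Token("slept"), Token("here")]
# Label tokens 0-2 (end index exclusive) as a PER entity.
entities = [TokenEntity(0, 2, "PER", value="George Washington")]

sentence = create_labeled_sentence_from_tokens(tokens, entities)
print(sentence.get_spans("ner"))  # expect one PER span: "George Washington"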

def create_sentence_chunks(
text: str,
entities: List[CharEntity],
token_limit: int = 512,
use_context: bool = True,
overlap: int = 0, # TODO: implement overlap
) -> List[Sentence]:
"""Chunks and labels a text from a list of entity annotations.

The function tokenizes the text and the entity spans separately, ensuring that entity labels are
never split partway through a token.

Args:
text (str): The full text to be tokenized and labeled.
entities: ordered, non-overlapping CharEntity annotations, each carrying
(start_char_idx, end_char_idx, label, value).
token_limit: the maximum number of tokens per chunk; pass `inf` to disable chunking
use_context: whether to link the resulting chunks as context for one another
overlap: the size of the overlap between chunks, repeating the last n tokens of the previous chunk to preserve context

Returns:
A list of labeled Sentence objects representing the chunks of the original text
"""
chunks = []

tokens: List[Token] = []
current_index = 0
token_entities: List[TokenEntity] = []
end_token_idx = 0

for entity in entities:

if entity.start_char_idx > current_index: # add non-entity text
non_entity_tokens = Sentence(text[current_index : entity.start_char_idx]).tokens
while end_token_idx + len(non_entity_tokens) > token_limit:
num_tokens = token_limit - len(tokens)
tokens.extend(non_entity_tokens[:num_tokens])
non_entity_tokens = non_entity_tokens[num_tokens:]
# skip any fully negative samples; they cause fine_tune to fail with
# `torch.cat(): expected a non-empty list of Tensors`
if len(token_entities) > 0:
chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
tokens, token_entities = [], []
end_token_idx = 0
tokens.extend(non_entity_tokens)

# add new entity tokens
start_token_idx = len(tokens)
entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx])
if len(entity_sentence) > token_limit:
logger.warning(f"Entity length is greater than token limit! {len(entity_sentence)} > {token_limit}")
end_token_idx = start_token_idx + len(entity_sentence)

if end_token_idx >= token_limit: # create chunk from existing and add this entity to next chunk
chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))

tokens, token_entities = [], []
start_token_idx, end_token_idx = 0, len(entity_sentence)

token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score)
token_entities.append(token_entity)
tokens.extend(entity_sentence)

current_index = entity.end_char_idx

# add any remaining tokens to a new chunk
if current_index < len(text):
remaining_sentence = Sentence(text[current_index:])
if end_token_idx + len(remaining_sentence) > token_limit:
chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
tokens, token_entities = [], []
tokens.extend(remaining_sentence)

if tokens:
chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))

for chunk in chunks:
if len(chunk) > token_limit:
logger.warning(f"Chunk size is longer than token limit: {len(chunk)} > {token_limit}")

if use_context:
Sentence.set_context_for_sentences(chunks)

return chunks
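
As a closing usage sketch of the new chunking entry point (the text, character offsets, and `token_limit` are invented for the example; `Sentence.to_tagged_string` is assumed from the existing flair API):

from flair.training_utils import CharEntity, create_sentence_chunks

text = "Albert Einstein was born in Ulm. He developed the theory of relativity."
entities = [
    CharEntity(0, 15, "PER", "Albert Einstein"),
    CharEntity(28, 31, "LOC", "Ulm"),
]

# Chunk the annotated text into Sentences of at most 10 tokens each;
# entity labels are carried over and are never split across chunk boundaries.
chunks = create_sentence_chunks(text, entities, token_limit=10)
for chunk in chunks:
    print(len(chunk), chunk.to_tagged_string())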