Skip to content

Commit

Permalink
Merge pull request #40 from ArneBinder/scores_for_token_classification
Browse files Browse the repository at this point in the history
labeled span extraction: add scores for predicted spans
  • Loading branch information
ArneBinder authored Jan 22, 2024
2 parents cea6058 + 9bdfd60 commit 2b75e03
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 73 deletions.
10 changes: 6 additions & 4 deletions src/pie_modules/models/simple_token_classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional, Tuple
from typing import MutableMapping, Optional, Tuple, Union

import torch
from pytorch_ie.core import PyTorchIEModel
Expand All @@ -15,7 +15,7 @@
# model inputs / outputs / targets
InputType: TypeAlias = BatchEncoding
OutputType: TypeAlias = TokenClassifierOutput
TargetType: TypeAlias = LongTensor
TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]]
# step inputs (batch) / outputs (loss)
StepInputType: TypeAlias = Tuple[InputType, TargetType]
StepOutputType: TypeAlias = FloatTensor
Expand Down Expand Up @@ -76,7 +76,7 @@ def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> Ou
inputs_without_special_tokens_mask = {
k: v for k, v in inputs.items() if k != "special_tokens_mask"
}
return self.model(labels=targets, **inputs_without_special_tokens_mask)
return self.model(**inputs_without_special_tokens_mask, **(targets or {}))

def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
# get the max index for each token from the logits
Expand All @@ -89,7 +89,9 @@ def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
tags_tensor = tags_tensor.masked_fill(
inputs["special_tokens_mask"] == 1, self.label_pad_id
)
return tags_tensor
probabilities = torch.softmax(outputs.logits, dim=-1)

return {"labels": tags_tensor, "probabilities": probabilities}

def configure_optimizers(self) -> OptimizerLRScheduler:
    """Build the optimizer used for training.

    Returns:
        A ``torch.optim.Adam`` optimizer over ``self.parameters()`` that uses
        ``self.learning_rate`` as its learning rate.
    """
    trainable_parameters = self.parameters()
    optimizer = torch.optim.Adam(trainable_parameters, lr=self.learning_rate)
    return optimizer
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, MutableMapping, Optional, Tuple, Union

import torch
from pytorch_ie.core import PyTorchIEModel
Expand All @@ -22,7 +22,7 @@
# model inputs / outputs / targets
InputType: TypeAlias = BatchEncoding
OutputType: TypeAlias = TokenClassifierOutput
TargetType: TypeAlias = LongTensor
TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]]
# step inputs (batch) / outputs (loss)
StepInputType: TypeAlias = Tuple[InputType, TargetType]
StepOutputType: TypeAlias = FloatTensor
Expand Down Expand Up @@ -141,6 +141,7 @@ def __init__(
self.crf = CRF(num_tags=num_classes, batch_first=True) if use_crf else None

def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
result = {}
logits = outputs.logits
attention_mask = inputs["attention_mask"]
special_tokens_mask = inputs["special_tokens_mask"]
Expand All @@ -159,7 +160,12 @@ def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
# set the padding and special tokens to the label_pad_id
mask = attention_mask_bool & ~special_tokens_mask.to(torch.bool)
tags_tensor = tags_tensor.masked_fill(~mask, self.label_pad_id)
return tags_tensor

result["labels"] = tags_tensor
# TODO: verify that a softmax over the raw logits is also a meaningful probability estimate when the CRF is used
result["probabilities"] = torch.softmax(logits, dim=-1)

return result

def forward(
self, inputs: InputType, targets: Optional[TargetType] = None
Expand All @@ -177,8 +183,8 @@ def forward(
logits = self.classifier(sequence_output)

loss = None
labels = targets
if labels is not None:
if targets is not None:
labels = targets["labels"]
if self.crf is not None:
# Overwrite the padding labels with ignore_index. Note that this is different from the
# attention_mask, because the attention_mask includes special tokens, whereas the labels
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
-> Document
"""

import copy
import logging
from typing import (
Any,
Expand All @@ -19,6 +18,7 @@
Set,
Tuple,
Type,
TypedDict,
Union,
)

Expand All @@ -45,6 +45,8 @@
TokenDocumentWithLabeledSpans,
TokenDocumentWithLabeledSpansAndLabeledPartitions,
)
from pie_modules.models.simple_token_classification import InputType as ModelInputType
from pie_modules.models.simple_token_classification import TargetType as ModelTargetType
from pie_modules.taskmodules.metrics import (
PrecisionRecallAndF1ForLabeledAnnotations,
WrappedMetricWithPrepareFunction,
Expand All @@ -61,11 +63,16 @@
TargetEncodingType,
]
ModelStepInputType: TypeAlias = Tuple[
Dict[str, torch.LongTensor],
Optional[torch.LongTensor],
ModelInputType,
Optional[ModelTargetType],
]
ModelOutputType: TypeAlias = torch.LongTensor
TaskOutputType: TypeAlias = torch.LongTensor
ModelOutputType: TypeAlias = ModelTargetType


class TaskOutputType(TypedDict, total=False):
    """Model output for a single batch entry, as passed to ``decode_annotations``.

    Declared with ``total=False``, so neither key is required to be present.
    """

    # label ids for one batch entry (a single slice of the batched "labels" tensor);
    # entries equal to the label pad id are decoded as "O"
    labels: torch.LongTensor
    # probabilities for the same batch entry; only set when the model output
    # provides them. Indexed as [token position, label id] when span scores
    # are computed.
    probabilities: torch.FloatTensor


TaskModuleType: TypeAlias = TaskModule[
DocumentType,
Expand Down Expand Up @@ -305,12 +312,22 @@ def collate(self, task_encodings: Sequence[TaskEncodingType]) -> ModelStepInputT
pad_mask = inputs["input_ids"] == self.tokenizer.pad_token_id
targets[pad_mask] = self.label_pad_id

return inputs, targets
return inputs, {"labels": targets}

def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputType]:
return [labels for labels in model_output.detach().cpu()]

def decode_annotations(self, labels: torch.LongTensor) -> Dict[str, Sequence[LabeledSpan]]:
labels = model_output["labels"]
probabilities = model_output.get("probabilities", None)
batch_size = labels.shape[0]
task_outputs: List[TaskOutputType] = []
for batch_idx in range(batch_size):
task_output: TaskOutputType = {"labels": labels[batch_idx]}
if probabilities is not None:
task_output["probabilities"] = probabilities[batch_idx]
task_outputs.append(task_output)
return task_outputs

def decode_annotations(self, encoding: TaskOutputType) -> Dict[str, Sequence[LabeledSpan]]:
labels = encoding["labels"]
tag_sequence = [
"O" if tag_id == self.label_pad_id else self.id_to_label[tag_id]
for tag_id in labels.tolist()
Expand All @@ -319,7 +336,19 @@ def decode_annotations(self, labels: torch.LongTensor) -> Dict[str, Sequence[Lab
for label, (start, end_inclusive) in bio_tags_to_spans(
tag_sequence, include_ill_formed=self.include_ill_formed_predictions
):
labeled_span = LabeledSpan(label=label, start=start, end=end_inclusive + 1)
end = end_inclusive + 1
# do not set the score if the probabilities are not available
annotation_kwargs = {}
if encoding.get("probabilities") is not None:
span_probabilities = encoding["probabilities"][start:end]
span_label_ids = labels[start:end]
# get the probabilities at the label indices
span_label_probs = torch.stack(
[span_probabilities[i, l] for i, l in enumerate(span_label_ids)]
)
# use mean probability of the span as score
annotation_kwargs["score"] = span_label_probs.mean().item()
labeled_span = LabeledSpan(label=label, start=start, end=end, **annotation_kwargs)
labeled_spans.append(labeled_span)
return {"labeled_spans": labeled_spans}

Expand Down Expand Up @@ -356,7 +385,8 @@ def create_annotations_from_output(
yield self.span_annotation, span.copy()

def configure_model_metric(self, stage: str) -> Union[Metric, MetricCollection]:
def remove_label_pad_ids(labels: torch.LongTensor) -> torch.LongTensor:
def remove_label_pad_ids(model_output: ModelOutputType) -> torch.LongTensor:
labels = model_output["labels"]
# remove the special tokens and padding from the predicted / target labels
# because the label_pad_id is usually not a valid index (e.g. -100)
mask = labels != self.label_pad_id
Expand Down
60 changes: 50 additions & 10 deletions tests/models/test_simple_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,16 @@ def batch():
]
),
}
targets = torch.tensor(
[
[-100, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100],
[-100, 3, 4, 4, 4, 0, 0, 1, 0, -100, -100, -100],
[-100, 3, 4, 4, 4, 0, 1, 0, -100, -100, -100, -100],
[-100, 0, 0, 0, 3, 4, 4, 4, 0, 1, 0, -100],
]
)
targets = {
"labels": torch.tensor(
[
[-100, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100],
[-100, 3, 4, 4, 4, 0, 0, 1, 0, -100, -100, -100],
[-100, 3, 4, 4, 4, 0, 1, 0, -100, -100, -100, -100],
[-100, 0, 0, 0, 3, 4, 4, 4, 0, 1, 0, -100],
]
)
}
return inputs, targets


Expand Down Expand Up @@ -325,9 +327,11 @@ def test_predict_and_predict_step(model, batch, config, test_step):
predictions = model.predict_step(batch, batch_idx=0, dataloader_idx=0)
else:
predictions = model.predict(batch[0])
assert predictions.shape == batch[1].shape
assert set(predictions) == {"labels", "probabilities"}

assert predictions["labels"].shape == batch[1]["labels"].shape
torch.testing.assert_close(
predictions,
predictions["labels"],
torch.tensor(
[
[-100, 3, 3, 4, 3, -100, -100, -100, -100, -100, -100, -100],
Expand All @@ -337,6 +341,42 @@ def test_predict_and_predict_step(model, batch, config, test_step):
]
),
)
torch.testing.assert_close(
# just check the first two batch entries
predictions["probabilities"][:2].round(decimals=4),
torch.tensor(
[
[
[0.2174, 0.1566, 0.1572, 0.2730, 0.1958],
[0.2122, 0.1639, 0.1534, 0.2588, 0.2118],
[0.2025, 0.1550, 0.1639, 0.2435, 0.2350],
[0.2068, 0.1484, 0.1741, 0.2110, 0.2597],
[0.2070, 0.1549, 0.1788, 0.2574, 0.2020],
[0.1853, 0.1586, 0.1807, 0.2530, 0.2224],
[0.2037, 0.1738, 0.1533, 0.2833, 0.1857],
[0.2103, 0.1722, 0.1607, 0.2718, 0.1850],
[0.2280, 0.1564, 0.1881, 0.2386, 0.1888],
[0.2235, 0.1511, 0.1921, 0.2480, 0.1853],
[0.2140, 0.1564, 0.1934, 0.2585, 0.1776],
[0.2092, 0.1722, 0.1886, 0.2526, 0.1774],
],
[
[0.2065, 0.1866, 0.1883, 0.2549, 0.1637],
[0.2104, 0.1639, 0.2123, 0.2289, 0.1845],
[0.2240, 0.1775, 0.2206, 0.2265, 0.1515],
[0.2035, 0.1432, 0.2320, 0.2097, 0.2116],
[0.2158, 0.1984, 0.2031, 0.2141, 0.1685],
[0.2068, 0.1957, 0.2243, 0.2027, 0.1705],
[0.2199, 0.1799, 0.2423, 0.1896, 0.1682],
[0.2221, 0.1514, 0.2504, 0.2057, 0.1703],
[0.1869, 0.1378, 0.2121, 0.2749, 0.1883],
[0.1762, 0.1422, 0.2079, 0.2629, 0.2107],
[0.1927, 0.1553, 0.1657, 0.3043, 0.1819],
[0.1913, 0.1820, 0.1772, 0.2716, 0.1779],
],
]
),
)


def test_configure_optimizers(model):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,16 @@ def batch():
]
),
}
targets = torch.tensor(
[
[-100, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100],
[-100, 3, 4, 4, 4, 0, 0, 1, 0, -100, -100, -100],
[-100, 3, 4, 4, 4, 0, 1, 0, -100, -100, -100, -100],
[-100, 0, 0, 0, 3, 4, 4, 4, 0, 1, 0, -100],
]
)
targets = {
"labels": torch.tensor(
[
[-100, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100],
[-100, 3, 4, 4, 4, 0, 0, 1, 0, -100, -100, -100],
[-100, 3, 4, 4, 4, 0, 1, 0, -100, -100, -100, -100],
[-100, 0, 0, 0, 3, 4, 4, 4, 0, 1, 0, -100],
]
)
}
return inputs, targets


Expand Down Expand Up @@ -157,7 +159,7 @@ def test_tune_base_model():
def test_forward(batch, model):
inputs, targets = batch
batch_size, seq_len = inputs["input_ids"].shape
num_classes = int(torch.max(targets) + 1)
num_classes = int(torch.max(targets["labels"]) + 1)

# set seed to make sure the output is deterministic
torch.manual_seed(42)
Expand Down Expand Up @@ -439,10 +441,13 @@ def test_predict_and_predict_step(model, batch, config, test_step):
predictions = model.predict_step(batch, batch_idx=0, dataloader_idx=0)
else:
predictions = model.predict(batch[0])
assert predictions.shape == batch[1].shape

assert set(predictions) == {"labels", "probabilities"}
labels = predictions["labels"]
probabilities = predictions["probabilities"]
if config == {}:
torch.testing.assert_close(
predictions,
labels,
torch.tensor(
[
[-100, 1, 3, 1, 1, -100, -100, -100, -100, -100, -100, -100],
Expand All @@ -454,7 +459,7 @@ def test_predict_and_predict_step(model, batch, config, test_step):
)
elif config == {"use_crf": False}:
torch.testing.assert_close(
predictions,
labels,
torch.tensor(
[
[-100, 1, 3, 1, 1, -100, -100, -100, -100, -100, -100, -100],
Expand All @@ -467,6 +472,43 @@ def test_predict_and_predict_step(model, batch, config, test_step):
else:
raise ValueError(f"Unknown config: {config}")

assert labels.shape == batch[1]["labels"].shape
torch.testing.assert_close(
probabilities[:2].round(decimals=4),
torch.tensor(
[
[
[0.2123, 0.2090, 0.1691, 0.1896, 0.2199],
[0.1835, 0.2382, 0.1678, 0.2175, 0.1929],
[0.1997, 0.2078, 0.1597, 0.3080, 0.1247],
[0.1521, 0.2844, 0.2405, 0.1705, 0.1525],
[0.1523, 0.2406, 0.2073, 0.1842, 0.2155],
[0.2048, 0.1966, 0.1860, 0.2822, 0.1305],
[0.1997, 0.1635, 0.2037, 0.2107, 0.2223],
[0.1904, 0.2195, 0.1675, 0.2245, 0.1981],
[0.1834, 0.2070, 0.1912, 0.2497, 0.1688],
[0.1831, 0.1971, 0.1886, 0.2719, 0.1593],
[0.2021, 0.1710, 0.1825, 0.2984, 0.1459],
[0.2090, 0.1694, 0.1492, 0.3119, 0.1605],
],
[
[0.1950, 0.1239, 0.2854, 0.2325, 0.1632],
[0.2324, 0.1133, 0.1760, 0.2818, 0.1965],
[0.1906, 0.1211, 0.2027, 0.2170, 0.2687],
[0.2018, 0.1164, 0.2073, 0.2418, 0.2327],
[0.2354, 0.0762, 0.2061, 0.2774, 0.2050],
[0.1968, 0.0876, 0.2437, 0.3027, 0.1693],
[0.2154, 0.0789, 0.2183, 0.3195, 0.1680],
[0.2011, 0.0958, 0.2537, 0.2560, 0.1934],
[0.1979, 0.1001, 0.2898, 0.2209, 0.1913],
[0.2338, 0.0861, 0.2225, 0.3663, 0.0913],
[0.2280, 0.0760, 0.2654, 0.2864, 0.1441],
[0.2413, 0.0705, 0.2240, 0.2984, 0.1658],
],
]
),
)


def test_configure_optimizers(model):
model.trainer = Trainer(max_epochs=10)
Expand Down
Loading

0 comments on commit 2b75e03

Please sign in to comment.