From 9bdfd6059426b2f7b5f8eb404309b274896f0f73 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 22 Jan 2024 19:25:05 +0100
Subject: [PATCH] use mean prediction probabilities for scores of LabeledSpans

---
 .../models/simple_token_classification.py     | 10 ++-
 ...sification_with_seq2seq_encoder_and_crf.py | 16 ++--
 ...span_extraction_by_token_classification.py | 52 +++++++++---
 .../test_simple_token_classification.py       | 60 ++++++++++---
 ...sification_with_seq2seq_encoder_and_crf.py | 66 +++++++++++---
 ...span_extraction_by_token_classification.py | 85 ++++++++++++-------
 6 files changed, 216 insertions(+), 73 deletions(-)

diff --git a/src/pie_modules/models/simple_token_classification.py b/src/pie_modules/models/simple_token_classification.py
index b904c1309..c36a44341 100644
--- a/src/pie_modules/models/simple_token_classification.py
+++ b/src/pie_modules/models/simple_token_classification.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Optional, Tuple
+from typing import MutableMapping, Optional, Tuple, Union
 
 import torch
 from pytorch_ie.core import PyTorchIEModel
@@ -15,7 +15,7 @@
 # model inputs / outputs / targets
 InputType: TypeAlias = BatchEncoding
 OutputType: TypeAlias = TokenClassifierOutput
-TargetType: TypeAlias = LongTensor
+TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]]
 # step inputs (batch) / outputs (loss)
 StepInputType: TypeAlias = Tuple[InputType, TargetType]
 StepOutputType: TypeAlias = FloatTensor
@@ -76,7 +76,7 @@ def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> Ou
         inputs_without_special_tokens_mask = {
             k: v for k, v in inputs.items() if k != "special_tokens_mask"
         }
-        return self.model(labels=targets, **inputs_without_special_tokens_mask)
+        return self.model(**inputs_without_special_tokens_mask, **(targets or {}))
 
     def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
         # get the max index for each token from the logits
@@ -89,7 +89,9 @@ def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
         tags_tensor = tags_tensor.masked_fill(
             inputs["special_tokens_mask"] == 1, self.label_pad_id
         )
-        return tags_tensor
+        probabilities = torch.softmax(outputs.logits, dim=-1)
+
+        return {"labels": tags_tensor, "probabilities": probabilities}
 
     def configure_optimizers(self) -> OptimizerLRScheduler:
         return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
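Note: with this change, `decode()` of the simple token classification model no longer returns a bare `LongTensor` of tag ids but a mapping that also carries the per-token class distribution. A minimal sketch of the new contract (tensor shapes and the `label_pad_id` value are illustrative assumptions, not part of the patch):

    import torch

    batch_size, seq_len, num_classes = 2, 6, 5
    logits = torch.randn(batch_size, seq_len, num_classes)
    special_tokens_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
    special_tokens_mask[:, 0] = 1  # e.g. the [CLS] position
    label_pad_id = -100

    # hard predictions, with special tokens masked out
    labels = logits.argmax(dim=-1).masked_fill(special_tokens_mask == 1, label_pad_id)
    # per-token class distribution, shape (batch_size, seq_len, num_classes)
    probabilities = torch.softmax(logits, dim=-1)
    decoded = {"labels": labels, "probabilities": probabilities}

The `**(targets or {})` change in `forward()` keeps the underlying HF model call working both for training, where the mapping contributes a "labels" keyword, and for inference, where `targets` is None.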
diff --git a/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py b/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py
index 7eed9d44b..761041aa3 100644
--- a/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py
+++ b/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, MutableMapping, Optional, Tuple, Union
 
 import torch
 from pytorch_ie.core import PyTorchIEModel
@@ -22,7 +22,7 @@
 # model inputs / outputs / targets
 InputType: TypeAlias = BatchEncoding
 OutputType: TypeAlias = TokenClassifierOutput
-TargetType: TypeAlias = LongTensor
+TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]]
 # step inputs (batch) / outputs (loss)
 StepInputType: TypeAlias = Tuple[InputType, TargetType]
 StepOutputType: TypeAlias = FloatTensor
@@ -141,6 +141,7 @@ def __init__(
         self.crf = CRF(num_tags=num_classes, batch_first=True) if use_crf else None
 
     def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
+        result = {}
         logits = outputs.logits
         attention_mask = inputs["attention_mask"]
         special_tokens_mask = inputs["special_tokens_mask"]
@@ -159,7 +160,12 @@ def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
         # set the padding and special tokens to the label_pad_id
         mask = attention_mask_bool & ~special_tokens_mask.to(torch.bool)
         tags_tensor = tags_tensor.masked_fill(~mask, self.label_pad_id)
-        return tags_tensor
+
+        result["labels"] = tags_tensor
+        # TODO: is it correct to use this also in the case of the crf?
+        result["probabilities"] = torch.softmax(logits, dim=-1)
+
+        return result
 
     def forward(
         self, inputs: InputType, targets: Optional[TargetType] = None
@@ -177,8 +183,8 @@ def forward(
         logits = self.classifier(sequence_output)
 
         loss = None
-        labels = targets
-        if labels is not None:
+        if targets is not None:
+            labels = targets["labels"]
             if self.crf is not None:
                 # Overwrite the padding labels with ignore_index. Note that this is different from the
                 # attention_mask, because the attention_mask includes special tokens, whereas the labels
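Note on the TODO above: when the CRF head is active, the softmax of the emission logits is only a heuristic, because it ignores the learned transition scores. The exact per-token quantity would be the posterior marginal p(tag_t = k | x) from the forward-backward algorithm; the CRF package used here appears to expose only Viterbi decoding and the sequence log-likelihood, so the emission softmax serves as an approximation in both configurations. As a sketch of what the patch computes (nothing below is part of the patch itself):

    import torch

    def token_confidences(logits: torch.FloatTensor) -> torch.FloatTensor:
        # per-token distributions from the emission scores alone; with a CRF,
        # exact marginals would additionally require forward-backward over
        # the transition matrix
        return torch.softmax(logits, dim=-1)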
diff --git a/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py b/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py
index e7aabc5dc..d96d1c471 100644
--- a/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py
+++ b/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py
@@ -7,7 +7,6 @@
     -> Document
 """
 
-import copy
 import logging
 from typing import (
     Any,
@@ -19,6 +18,7 @@
     Set,
     Tuple,
     Type,
+    TypedDict,
     Union,
 )
 
@@ -45,6 +45,8 @@
     TokenDocumentWithLabeledSpans,
     TokenDocumentWithLabeledSpansAndLabeledPartitions,
 )
+from pie_modules.models.simple_token_classification import InputType as ModelInputType
+from pie_modules.models.simple_token_classification import TargetType as ModelTargetType
 from pie_modules.taskmodules.metrics import (
     PrecisionRecallAndF1ForLabeledAnnotations,
     WrappedMetricWithPrepareFunction,
@@ -61,11 +63,16 @@
     TargetEncodingType,
 ]
 ModelStepInputType: TypeAlias = Tuple[
-    Dict[str, torch.LongTensor],
-    Optional[torch.LongTensor],
+    ModelInputType,
+    Optional[ModelTargetType],
 ]
-ModelOutputType: TypeAlias = torch.LongTensor
-TaskOutputType: TypeAlias = torch.LongTensor
+ModelOutputType: TypeAlias = ModelTargetType
+
+
+class TaskOutputType(TypedDict, total=False):
+    labels: torch.LongTensor
+    probabilities: torch.FloatTensor
+
 
 TaskModuleType: TypeAlias = TaskModule[
     DocumentType,
@@ -305,12 +312,22 @@ def collate(self, task_encodings: Sequence[TaskEncodingType]) -> ModelStepInputT
             pad_mask = inputs["input_ids"] == self.tokenizer.pad_token_id
             targets[pad_mask] = self.label_pad_id
 
-        return inputs, targets
+        return inputs, {"labels": targets}
 
     def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputType]:
-        return [labels for labels in model_output.detach().cpu()]
-
-    def decode_annotations(self, labels: torch.LongTensor) -> Dict[str, Sequence[LabeledSpan]]:
+        labels = model_output["labels"]
+        probabilities = model_output.get("probabilities", None)
+        batch_size = labels.shape[0]
+        task_outputs: List[TaskOutputType] = []
+        for batch_idx in range(batch_size):
+            task_output: TaskOutputType = {"labels": labels[batch_idx]}
+            if probabilities is not None:
+                task_output["probabilities"] = probabilities[batch_idx]
+            task_outputs.append(task_output)
+        return task_outputs
+
+    def decode_annotations(self, encoding: TaskOutputType) -> Dict[str, Sequence[LabeledSpan]]:
+        labels = encoding["labels"]
         tag_sequence = [
             "O" if tag_id == self.label_pad_id else self.id_to_label[tag_id]
             for tag_id in labels.tolist()
@@ -319,7 +336,19 @@ def decode_annotations(self, labels: torch.LongTensor) -> Dict[str, Sequence[Lab
         for label, (start, end_inclusive) in bio_tags_to_spans(
             tag_sequence, include_ill_formed=self.include_ill_formed_predictions
         ):
-            labeled_span = LabeledSpan(label=label, start=start, end=end_inclusive + 1)
+            end = end_inclusive + 1
+            # do not set the score if the probabilities are not available
+            annotation_kwargs = {}
+            if encoding.get("probabilities") is not None:
+                span_probabilities = encoding["probabilities"][start:end]
+                span_label_ids = labels[start:end]
+                # get the probabilities at the label indices
+                span_label_probs = torch.stack(
+                    [span_probabilities[i, l] for i, l in enumerate(span_label_ids)]
+                )
+                # use mean probability of the span as score
+                annotation_kwargs["score"] = span_label_probs.mean().item()
+            labeled_span = LabeledSpan(label=label, start=start, end=end, **annotation_kwargs)
             labeled_spans.append(labeled_span)
         return {"labeled_spans": labeled_spans}
 
@@ -356,7 +385,8 @@ def create_annotations_from_output(
             yield self.span_annotation, span.copy()
 
     def configure_model_metric(self, stage: str) -> Union[Metric, MetricCollection]:
-        def remove_label_pad_ids(labels: torch.LongTensor) -> torch.LongTensor:
+        def remove_label_pad_ids(model_output: ModelOutputType) -> torch.LongTensor:
+            labels = model_output["labels"]
             # remove the special tokens and padding from the predicted / target labels
             # because the label_pad_id is usually not a valid index (e.g. -100)
             mask = labels != self.label_pad_id
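Note: `decode_annotations()` now scores each decoded span with the mean probability that the model assigned to the tokens' own predicted labels. A worked toy example (all numbers invented): for a two-token span whose tokens were tagged with label ids 1 and 2 at probabilities 0.8 and 0.6, the span score is (0.8 + 0.6) / 2 = 0.7. The same computation as in the patch, in isolation:

    import torch

    # per-token distributions for the two tokens of one span (start:end)
    span_probabilities = torch.tensor(
        [[0.1, 0.8, 0.1],
         [0.2, 0.2, 0.6]]
    )
    span_label_ids = torch.tensor([1, 2])  # predicted label id per token
    # pick, for each token, the probability of its own predicted label
    span_label_probs = torch.stack(
        [span_probabilities[i, l] for i, l in enumerate(span_label_ids)]
    )
    score = span_label_probs.mean().item()  # 0.7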
diff --git a/tests/models/test_simple_token_classification.py b/tests/models/test_simple_token_classification.py
index fb8ce39c1..ce3678b7d 100644
--- a/tests/models/test_simple_token_classification.py
+++ b/tests/models/test_simple_token_classification.py
@@ -90,14 +90,16 @@ def batch():
             ]
         ),
     }
-    targets = torch.tensor(
-        [
-            [-100, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100],
-            [-100, 3, 4, 4, 4, 0, 0, 1, 0, -100, -100, -100],
-            [-100, 3, 4, 4, 4, 0, 1, 0, -100, -100, -100, -100],
-            [-100, 0, 0, 0, 3, 4, 4, 4, 0, 1, 0, -100],
-        ]
-    )
+    targets = {
+        "labels": torch.tensor(
+            [
+                [-100, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100],
+                [-100, 3, 4, 4, 4, 0, 0, 1, 0, -100, -100, -100],
+                [-100, 3, 4, 4, 4, 0, 1, 0, -100, -100, -100, -100],
+                [-100, 0, 0, 0, 3, 4, 4, 4, 0, 1, 0, -100],
+            ]
+        )
+    }
 
     return inputs, targets
 
@@ -325,9 +327,11 @@ def test_predict_and_predict_step(model, batch, config, test_step):
         predictions = model.predict_step(batch, batch_idx=0, dataloader_idx=0)
     else:
         predictions = model.predict(batch[0])
-    assert predictions.shape == batch[1].shape
+    assert set(predictions) == {"labels", "probabilities"}
+
+    assert predictions["labels"].shape == batch[1]["labels"].shape
     torch.testing.assert_close(
-        predictions,
+        predictions["labels"],
        torch.tensor(
             [
                 [-100, 3, 3, 4, 3, -100, -100, -100, -100, -100, -100, -100],
@@ -337,6 +341,42 @@ def test_predict_and_predict_step(model, batch, config, test_step):
             ]
         ),
     )
+    torch.testing.assert_close(
+        # just check the first two batch entries
+        predictions["probabilities"][:2].round(decimals=4),
+        torch.tensor(
+            [
+                [
+                    [0.2174, 0.1566, 0.1572, 0.2730, 0.1958],
+                    [0.2122, 0.1639, 0.1534, 0.2588, 0.2118],
+                    [0.2025, 0.1550, 0.1639, 0.2435, 0.2350],
+                    [0.2068, 0.1484, 0.1741, 0.2110, 0.2597],
+                    [0.2070, 0.1549, 0.1788, 0.2574, 0.2020],
+                    [0.1853, 0.1586, 0.1807, 0.2530, 0.2224],
+                    [0.2037, 0.1738, 0.1533, 0.2833, 0.1857],
+                    [0.2103, 0.1722, 0.1607, 0.2718, 0.1850],
+                    [0.2280, 0.1564, 0.1881, 0.2386, 0.1888],
+                    [0.2235, 0.1511, 0.1921, 0.2480, 0.1853],
+                    [0.2140, 0.1564, 0.1934, 0.2585, 0.1776],
+                    [0.2092, 0.1722, 0.1886, 0.2526, 0.1774],
+                ],
+                [
+                    [0.2065, 0.1866, 0.1883, 0.2549, 0.1637],
+                    [0.2104, 0.1639, 0.2123, 0.2289, 0.1845],
+                    [0.2240, 0.1775, 0.2206, 0.2265, 0.1515],
+                    [0.2035, 0.1432, 0.2320, 0.2097, 0.2116],
+                    [0.2158, 0.1984, 0.2031, 0.2141, 0.1685],
+                    [0.2068, 0.1957, 0.2243, 0.2027, 0.1705],
+                    [0.2199, 0.1799, 0.2423, 0.1896, 0.1682],
+                    [0.2221, 0.1514, 0.2504, 0.2057, 0.1703],
+                    [0.1869, 0.1378, 0.2121, 0.2749, 0.1883],
+                    [0.1762, 0.1422, 0.2079, 0.2629, 0.2107],
+                    [0.1927, 0.1553, 0.1657, 0.3043, 0.1819],
+                    [0.1913, 0.1820, 0.1772, 0.2716, 0.1779],
+                ],
+            ]
+        ),
+    )
 
 
 def test_configure_optimizers(model):
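Note: the hard-coded probability tensors in this test (and in the CRF test below) are tied to the deterministically seeded model weights from the test fixtures, and are compared after `round(decimals=4)` to keep the assertion stable and readable. If a fixture or seed changes, the expected values could be regenerated along these lines (a sketch, assuming the `model` and `batch` fixtures of the test module):

    torch.manual_seed(42)  # assumed to match the seed used by the model fixture
    predictions = model.predict(batch[0])
    print(predictions["probabilities"][:2].round(decimals=4))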
diff --git a/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py b/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py
index ef3f849b6..008b91a23 100644
--- a/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py
+++ b/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py
@@ -89,14 +89,16 @@ def batch():
             ]
         ),
     }
-    targets = torch.tensor(
-        [
-            [-100, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100],
-            [-100, 3, 4, 4, 4, 0, 0, 1, 0, -100, -100, -100],
-            [-100, 3, 4, 4, 4, 0, 1, 0, -100, -100, -100, -100],
-            [-100, 0, 0, 0, 3, 4, 4, 4, 0, 1, 0, -100],
-        ]
-    )
+    targets = {
+        "labels": torch.tensor(
+            [
+                [-100, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100],
+                [-100, 3, 4, 4, 4, 0, 0, 1, 0, -100, -100, -100],
+                [-100, 3, 4, 4, 4, 0, 1, 0, -100, -100, -100, -100],
+                [-100, 0, 0, 0, 3, 4, 4, 4, 0, 1, 0, -100],
+            ]
+        )
+    }
 
     return inputs, targets
 
@@ -157,7 +159,7 @@ def test_tune_base_model():
 def test_forward(batch, model):
     inputs, targets = batch
     batch_size, seq_len = inputs["input_ids"].shape
-    num_classes = int(torch.max(targets) + 1)
+    num_classes = int(torch.max(targets["labels"]) + 1)
 
     # set seed to make sure the output is deterministic
     torch.manual_seed(42)
@@ -439,10 +441,13 @@ def test_predict_and_predict_step(model, batch, config, test_step):
         predictions = model.predict_step(batch, batch_idx=0, dataloader_idx=0)
     else:
         predictions = model.predict(batch[0])
-    assert predictions.shape == batch[1].shape
+
+    assert set(predictions) == {"labels", "probabilities"}
+    labels = predictions["labels"]
+    probabilities = predictions["probabilities"]
     if config == {}:
         torch.testing.assert_close(
-            predictions,
+            labels,
             torch.tensor(
                 [
                     [-100, 1, 3, 1, 1, -100, -100, -100, -100, -100, -100, -100],
@@ -454,7 +459,7 @@ def test_predict_and_predict_step(model, batch, config, test_step):
         )
     elif config == {"use_crf": False}:
         torch.testing.assert_close(
-            predictions,
+            labels,
             torch.tensor(
                 [
                     [-100, 1, 3, 1, 1, -100, -100, -100, -100, -100, -100, -100],
@@ -467,6 +472,43 @@ def test_predict_and_predict_step(model, batch, config, test_step):
                 ]
             ),
         )
    else:
         raise ValueError(f"Unknown config: {config}")
 
+    assert labels.shape == batch[1]["labels"].shape
+    torch.testing.assert_close(
+        probabilities[:2].round(decimals=4),
+        torch.tensor(
+            [
+                [
+                    [0.2123, 0.2090, 0.1691, 0.1896, 0.2199],
+                    [0.1835, 0.2382, 0.1678, 0.2175, 0.1929],
+                    [0.1997, 0.2078, 0.1597, 0.3080, 0.1247],
+                    [0.1521, 0.2844, 0.2405, 0.1705, 0.1525],
+                    [0.1523, 0.2406, 0.2073, 0.1842, 0.2155],
+                    [0.2048, 0.1966, 0.1860, 0.2822, 0.1305],
+                    [0.1997, 0.1635, 0.2037, 0.2107, 0.2223],
+                    [0.1904, 0.2195, 0.1675, 0.2245, 0.1981],
+                    [0.1834, 0.2070, 0.1912, 0.2497, 0.1688],
+                    [0.1831, 0.1971, 0.1886, 0.2719, 0.1593],
+                    [0.2021, 0.1710, 0.1825, 0.2984, 0.1459],
+                    [0.2090, 0.1694, 0.1492, 0.3119, 0.1605],
+                ],
+                [
+                    [0.1950, 0.1239, 0.2854, 0.2325, 0.1632],
+                    [0.2324, 0.1133, 0.1760, 0.2818, 0.1965],
+                    [0.1906, 0.1211, 0.2027, 0.2170, 0.2687],
+                    [0.2018, 0.1164, 0.2073, 0.2418, 0.2327],
+                    [0.2354, 0.0762, 0.2061, 0.2774, 0.2050],
+                    [0.1968, 0.0876, 0.2437, 0.3027, 0.1693],
+                    [0.2154, 0.0789, 0.2183, 0.3195, 0.1680],
+                    [0.2011, 0.0958, 0.2537, 0.2560, 0.1934],
+                    [0.1979, 0.1001, 0.2898, 0.2209, 0.1913],
+                    [0.2338, 0.0861, 0.2225, 0.3663, 0.0913],
+                    [0.2280, 0.0760, 0.2654, 0.2864, 0.1441],
+                    [0.2413, 0.0705, 0.2240, 0.2984, 0.1658],
+                ],
+            ]
+        ),
+    )
+
 
 def test_configure_optimizers(model):
     model.trainer = Trainer(max_epochs=10)
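Note: in the taskmodule tests below, the `model_output` fixture fabricates a "perfect" prediction instead of running a model: pad positions are first mapped to the "O" label id so that they are valid indices for `one_hot`, and the resulting one-hot distribution is scaled by 0.9. Every decoded span should therefore come out with a score of exactly 0.9, which `test_decode_annotations` asserts. A condensed sketch of that construction (label ids and the number of classes are illustrative):

    import torch

    label_pad_id = -100
    o_label_id = 0  # assumed id of the "O" label
    labels = torch.tensor([[-100, 1, 2, 0, -100]])
    labels_valid = labels.clone()
    labels_valid[labels_valid == label_pad_id] = o_label_id
    probabilities = torch.nn.functional.one_hot(
        labels_valid, num_classes=5
    ).to(torch.float32) * 0.9
    model_output = {"labels": labels, "probabilities": probabilities}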
diff --git a/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py b/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py
index 37c03b762..5e3c79375 100644
--- a/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py
+++ b/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py
@@ -16,6 +16,9 @@
 from transformers import BatchEncoding
 
 from pie_modules.taskmodules import LabeledSpanExtractionByTokenClassificationTaskModule
+from pie_modules.taskmodules.labeled_span_extraction_by_token_classification import (
+    ModelOutputType,
+)
 
 
 def _config_to_str(cfg: Dict[str, Any]) -> str:
@@ -384,8 +387,9 @@ def test_collate(batch, config):
     assert set(inputs.data) == {"input_ids", "attention_mask", "special_tokens_mask"}
     input_ids_list = inputs.input_ids.tolist()
     attention_mask_list = inputs.attention_mask.tolist()
-    targets_list = targets.tolist()
     special_tokens_mask_list = inputs.special_tokens_mask.tolist()
+    assert set(targets) == {"labels"}
+    labels_list = targets["labels"].tolist()
 
     # If config is empty
     if config == CONFIG_DEFAULT:
@@ -397,7 +401,7 @@ def test_collate(batch, config):
             [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
             [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         ]
-        assert targets_list == [
+        assert labels_list == [
             [-100, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, -100],
             [-100, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0, -100],
         ]
@@ -420,7 +424,7 @@ def test_collate(batch, config):
             [1, 1, 1, 1, 1, 1, 1, 1],
             [1, 1, 1, 1, 1, 1, 1, 1],
         ]
-        assert targets_list == [
+        assert labels_list == [
             [-100, 1, 2, 0, 0, 0, 0, -100],
             [-100, 0, 0, 0, 0, 0, 0, -100],
             [-100, 3, 0, 0, 0, 0, 3, -100],
@@ -441,7 +445,7 @@ def test_collate(batch, config):
             [1, 1, 1, 1, 1, 1, 1, 1],
             [1, 1, 1, 1, 1, 1, 0, 0],
         ]
-        assert targets_list == [
+        assert labels_list == [
             [-100, 1, 2, 0, 0, 0, 0, -100],
             [-100, 0, 0, 0, 0, -100, -100, -100],
             [-100, 3, 0, 0, 0, 0, 3, -100],
@@ -458,7 +462,7 @@ def test_collate(batch, config):
     elif config == CONFIG_PARTITIONS:
         assert input_ids_list == [[101, 3960, 15646, 2652, 4715, 1012, 102]]
         assert attention_mask_list == [[1, 1, 1, 1, 1, 1, 1]]
-        assert targets_list == [[-100, 3, 0, 0, 0, 0, -100]]
+        assert labels_list == [[-100, 3, 0, 0, 0, 0, -100]]
         assert special_tokens_mask_list == [[1, 0, 0, 0, 0, 0, 1]]
 
     else:
@@ -472,8 +476,8 @@ def test_collate(batch, config):
         }
     )
     assert set(inputs.data) == set(inputs_expected.data)
-    targets_expected = torch.tensor(targets_list, dtype=torch.int64)
-    assert torch.equal(targets, targets_expected)
+    labels_expected = torch.tensor(labels_list, dtype=torch.int64)
+    assert torch.equal(targets["labels"], labels_expected)
 
 
 # This is not used, but can be used to create a batch of task encodings with targets for the unbatched_outputs fixture.
@@ -491,10 +495,18 @@ def real_model_output(batch, taskmodule):
 
 
 @pytest.fixture(scope="module")
-def model_output(config, batch, taskmodule) -> torch.LongTensor:
+def model_output(config, batch, taskmodule) -> ModelOutputType:
     # create "perfect" output from targets
-    targets = batch[1].clone()
-    return targets
+    labels = batch[1]["labels"]
+    num_classes = len(taskmodule.label_to_id)
+    # map the pad ids to the id of the "O" label so that all ids are valid indices
+    labels_valid = labels.clone()
+    labels_valid[labels_valid == taskmodule.label_pad_id] = taskmodule.label_to_id["O"]
+    # create one-hot encoding from labels, but with 0.9 for the correct labels
+    probabilities = (
+        torch.nn.functional.one_hot(labels_valid, num_classes=num_classes).to(torch.float32) * 0.9
+    )
+    return {"labels": labels, "probabilities": probabilities}
 
 
 @pytest.fixture(scope="module")
@@ -508,42 +520,46 @@ def test_unbatched_output(unbatched_outputs, config):
     if config == CONFIG_DEFAULT:
         assert len(unbatched_outputs) == 2
         torch.testing.assert_close(
-            unbatched_outputs[0], torch.tensor([-100, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, -100])
+            unbatched_outputs[0]["labels"],
+            torch.tensor([-100, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, -100]),
         )
         torch.testing.assert_close(
-            unbatched_outputs[1], torch.tensor([-100, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0, -100])
+            unbatched_outputs[1]["labels"],
+            torch.tensor([-100, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0, -100]),
         )
     elif config == CONFIG_MAX_WINDOW_WITH_STRIDE:
         assert len(unbatched_outputs) == 4
         torch.testing.assert_close(
-            unbatched_outputs[0], torch.tensor([-100, 1, 2, 0, 0, 0, 0, -100])
+            unbatched_outputs[0]["labels"], torch.tensor([-100, 1, 2, 0, 0, 0, 0, -100])
         )
         torch.testing.assert_close(
-            unbatched_outputs[1], torch.tensor([-100, 0, 0, 0, 0, 0, 0, -100])
+            unbatched_outputs[1]["labels"], torch.tensor([-100, 0, 0, 0, 0, 0, 0, -100])
        )
         torch.testing.assert_close(
-            unbatched_outputs[2], torch.tensor([-100, 3, 0, 0, 0, 0, 3, -100])
+            unbatched_outputs[2]["labels"], torch.tensor([-100, 3, 0, 0, 0, 0, 3, -100])
         )
         torch.testing.assert_close(
-            unbatched_outputs[3], torch.tensor([-100, 0, 3, 0, 0, 0, 0, -100])
+            unbatched_outputs[3]["labels"], torch.tensor([-100, 0, 3, 0, 0, 0, 0, -100])
        )
     elif config == CONFIG_MAX_WINDOW:
         assert len(unbatched_outputs) == 4
         torch.testing.assert_close(
-            unbatched_outputs[0], torch.tensor([-100, 1, 2, 0, 0, 0, 0, -100])
+            unbatched_outputs[0]["labels"], torch.tensor([-100, 1, 2, 0, 0, 0, 0, -100])
         )
         torch.testing.assert_close(
-            unbatched_outputs[1], torch.tensor([-100, 0, 0, 0, 0, -100, -100, -100])
+            unbatched_outputs[1]["labels"], torch.tensor([-100, 0, 0, 0, 0, -100, -100, -100])
         )
         torch.testing.assert_close(
-            unbatched_outputs[2], torch.tensor([-100, 3, 0, 0, 0, 0, 3, -100])
+            unbatched_outputs[2]["labels"], torch.tensor([-100, 3, 0, 0, 0, 0, 3, -100])
         )
         torch.testing.assert_close(
-            unbatched_outputs[3], torch.tensor([-100, 0, 0, 0, 0, -100, -100, -100])
+            unbatched_outputs[3]["labels"], torch.tensor([-100, 0, 0, 0, 0, -100, -100, -100])
         )
     elif config == CONFIG_PARTITIONS:
         assert len(unbatched_outputs) == 1
-        torch.testing.assert_close(unbatched_outputs[0], torch.tensor([-100, 3, 0, 0, 0, 0, -100]))
+        torch.testing.assert_close(
+            unbatched_outputs[0]["labels"], torch.tensor([-100, 3, 0, 0, 0, 0, -100])
+        )
     else:
         raise ValueError(f"unknown config: {config}")
@@ -579,22 +595,22 @@ def test_decode_annotations(taskmodule, unbatched_outputs, config):
         # We get two annotations for Bob because the window overlaps with the previous one.
         # This is not a problem because annotations get de-duplicated during serialization.
         assert annotations == [
-            [LabeledSpan(start=1, end=3, label="LOC", score=1.0)],
+            [LabeledSpan(start=1, end=3, label="LOC")],
             [],
             [
-                LabeledSpan(start=1, end=2, label="PER", score=1.0),
-                LabeledSpan(start=6, end=7, label="PER", score=1.0),
+                LabeledSpan(start=1, end=2, label="PER"),
+                LabeledSpan(start=6, end=7, label="PER"),
             ],
-            [LabeledSpan(start=2, end=3, label="PER", score=1.0)],
+            [LabeledSpan(start=2, end=3, label="PER")],
         ]
     elif config == CONFIG_MAX_WINDOW:
         assert annotations == [
-            [LabeledSpan(start=1, end=3, label="LOC", score=1.0)],
+            [LabeledSpan(start=1, end=3, label="LOC")],
             [],
             [
-                LabeledSpan(start=1, end=2, label="PER", score=1.0),
-                LabeledSpan(start=6, end=7, label="PER", score=1.0),
+                LabeledSpan(start=1, end=2, label="PER"),
+                LabeledSpan(start=6, end=7, label="PER"),
             ],
             [],
         ]
@@ -605,6 +621,11 @@ def test_decode_annotations(taskmodule, unbatched_outputs, config):
     else:
         raise ValueError(f"unknown config: {config}")
 
+    # assert that all scores are 0.9
+    for doc_annotations in annotations:
+        for annotation in doc_annotations:
+            assert round(annotation.score, 4) == 0.9
+
 
 @pytest.fixture(scope="module")
 def annotations_from_output(taskmodule, task_encodings_for_batch, unbatched_outputs, config):
@@ -756,10 +777,12 @@ def test_configure_model_metric(documents):
         "token/micro/f1": tensor(1.0),
     }
 
-    predictions = torch.ones_like(targets)
+    target_labels = targets["labels"]
+    predicted_labels = torch.ones_like(target_labels)
     # we need to set the same padding as in the targets
-    predictions[targets == taskmodule.label_pad_id] = taskmodule.label_pad_id
-    metric.update(predictions, targets)
+    predicted_labels[target_labels == taskmodule.label_pad_id] = taskmodule.label_pad_id
+    prediction = {"labels": predicted_labels}
+    metric.update(prediction, targets)
     values = metric.compute()
     values_converted = {k: v.item() for k, v in values.items()}
     assert values_converted == {
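Note: taken together, the patch means predicted `LabeledSpan` annotations carry a meaningful confidence in their `score` field instead of the pytorch-ie default of 1.0 (visible in the test expectations above, where the explicit `score=1.0` arguments were dropped), so downstream code can rank or filter spans by it. A hedged usage sketch, assuming `document` is a pytorch-ie text document processed with the taskmodule/model pair from this patch:

    for span in document.labeled_spans.predictions:
        if span.score >= 0.5:  # keep only reasonably confident spans
            print(span.start, span.end, span.label, round(span.score, 4))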