From 3923dbe68ba28c94d640fdc272c3a28b0fb04414 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 11 Sep 2024 22:48:56 +0200 Subject: [PATCH 01/49] implement CrossTextBinaryCorefTaskModule --- src/pie_modules/document/types.py | 75 ++++ src/pie_modules/taskmodules/__init__.py | 1 + .../taskmodules/cross_text_binary_coref.py | 208 ++++++++++ .../test_cross_text_binary_coref.py | 367 ++++++++++++++++++ 4 files changed, 651 insertions(+) create mode 100644 src/pie_modules/document/types.py create mode 100644 src/pie_modules/taskmodules/cross_text_binary_coref.py create mode 100644 tests/taskmodules/test_cross_text_binary_coref.py diff --git a/src/pie_modules/document/types.py b/src/pie_modules/document/types.py new file mode 100644 index 000000000..b327a7903 --- /dev/null +++ b/src/pie_modules/document/types.py @@ -0,0 +1,75 @@ +import dataclasses + +from pytorch_ie import AnnotationLayer, annotation_field +from pytorch_ie.annotations import BinaryRelation, LabeledSpan +from pytorch_ie.documents import ( + TextBasedDocument, + TextDocumentWithLabeledPartitions, + TextDocumentWithLabeledSpans, + TextDocumentWithLabeledSpansAndLabeledPartitions, +) + + +@dataclasses.dataclass +class WithTextPair: + text_pair: str + + +@dataclasses.dataclass +class WithLabeledSpansPair(WithTextPair): + labeled_spans_pair: AnnotationLayer[LabeledSpan] = annotation_field(target="text_pair") + + +@dataclasses.dataclass +class WithLabeledPartitionsPair(WithTextPair): + labeled_partitions_pair: AnnotationLayer[LabeledSpan] = annotation_field(target="text_pair") + + +@dataclasses.dataclass +class TextPairBasedDocument(TextBasedDocument, WithTextPair): + pass + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledPartitions( + WithLabeledPartitionsPair, TextPairBasedDocument, TextDocumentWithLabeledPartitions +): + pass + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledSpans( + WithLabeledSpansPair, TextPairBasedDocument, TextDocumentWithLabeledSpans +): + pass + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledSpansAndLabeledPartitions( + TextPairDocumentWithLabeledPartitions, + TextPairDocumentWithLabeledSpans, + TextDocumentWithLabeledSpansAndLabeledPartitions, +): + pass + + +@dataclasses.dataclass(eq=True, frozen=True) +class BinaryCorefRelation(BinaryRelation): + label: str = "coref" + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + TextPairDocumentWithLabeledSpans, TextDocumentWithLabeledSpans +): + binary_coref_relations: AnnotationLayer[BinaryCorefRelation] = annotation_field( + targets=["labeled_spans", "labeled_spans_pair"] + ) + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledSpansSimilarityRelationsAndLabeledPartitions( + TextPairDocumentWithLabeledSpansAndLabeledPartitions, + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, +): + pass diff --git a/src/pie_modules/taskmodules/__init__.py b/src/pie_modules/taskmodules/__init__.py index 7fdedab4c..46d3766ae 100644 --- a/src/pie_modules/taskmodules/__init__.py +++ b/src/pie_modules/taskmodules/__init__.py @@ -1,3 +1,4 @@ +from .cross_text_binary_coref import CrossTextBinaryCorefTaskModule from .extractive_question_answering import ExtractiveQuestionAnsweringTaskModule from .labeled_span_extraction_by_token_classification import ( LabeledSpanExtractionByTokenClassificationTaskModule, diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py new file mode 100644 index 000000000..b045ee096 --- 
/dev/null
+++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py
@@ -0,0 +1,208 @@
+import logging
+from collections import defaultdict
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    Iterator,
+    Optional,
+    Sequence,
+    Tuple,
+    TypedDict,
+    Union,
+)
+
+import torch
+from pytorch_ie import Annotation
+from pytorch_ie.core import TaskEncoding, TaskModule
+from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
+from torchmetrics import MetricCollection
+from torchmetrics.classification import BinaryAUROC
+from transformers import AutoTokenizer
+from typing_extensions import TypeAlias
+
+from pie_modules.document.types import (
+    BinaryCorefRelation,
+    TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
+)
+from pie_modules.utils import list_of_dicts2dict_of_lists
+
+logger = logging.getLogger(__name__)
+
+InputEncodingType: TypeAlias = Dict[str, Any]
+TargetEncodingType: TypeAlias = float
+DocumentType: TypeAlias = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
+
+TaskEncodingType: TypeAlias = TaskEncoding[
+    DocumentType,
+    InputEncodingType,
+    TargetEncodingType,
+]
+
+
+class TaskOutputType(TypedDict, total=False):
+    scores: Sequence[float]
+
+
+ModelInputType: TypeAlias = Dict[str, torch.Tensor]
+ModelTargetType: TypeAlias = torch.Tensor
+ModelOutputType: TypeAlias = torch.Tensor
+
+TaskModuleType: TypeAlias = TaskModule[
+    # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput
+    DocumentType,
+    InputEncodingType,
+    TargetEncodingType,
+    Tuple[ModelInputType, Optional[ModelTargetType]],
+    ModelTargetType,
+    TaskOutputType,
+]
+
+
+@TaskModule.register()
+class CrossTextBinaryCorefTaskModule(TaskModuleType, ChangesTokenizerVocabSize):
+    DOCUMENT_TYPE = DocumentType
+
+    def __init__(
+        self,
+        tokenizer_name_or_path: str,
+        add_negative_relations: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.add_negative_relations = add_negative_relations
+
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+
+    def _add_negative_relations(self, positives: Iterable[DocumentType]) -> Iterable[DocumentType]:
+        positive_tuples = defaultdict(set)
+        text2spans = defaultdict(set)
+        for doc in positives:
+            for labeled_span in doc.labeled_spans:
+                text2spans[doc.text].add(labeled_span.copy())
+            for labeled_span in doc.labeled_spans_pair:
+                text2spans[doc.text_pair].add(labeled_span.copy())
+
+            for coref in doc.binary_coref_relations:
+                positive_tuples[(doc.text, doc.text_pair)].add(
+                    (coref.head.copy(), coref.tail.copy())
+                )
+                positive_tuples[(doc.text_pair, doc.text)].add(
+                    (coref.tail.copy(), coref.head.copy())
+                )
+
+        new_docs = []
+        for text in sorted(text2spans):
+            for text_pair in sorted(text2spans):
+                if text == text_pair:
+                    continue
+                current_positives = positive_tuples.get((text, text_pair), set())
+                new_doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
+                    text=text, text_pair=text_pair
+                )
+                new_doc.labeled_spans.extend(
+                    labeled_span.copy() for labeled_span in text2spans[text]
+                )
+                new_doc.labeled_spans_pair.extend(
+                    labeled_span.copy() for labeled_span in text2spans[text_pair]
+                )
+                for s in sorted(new_doc.labeled_spans):
+                    for s_p in sorted(new_doc.labeled_spans_pair):
+                        score = 1.0 if (s.copy(), s_p.copy()) in current_positives else 0.0
+                        new_coref_rel = BinaryCorefRelation(head=s, tail=s_p, score=score)
+                        new_doc.binary_coref_relations.append(new_coref_rel)
+                new_docs.append(new_doc)
+
+        return new_docs
+
+    def encode(self, documents: 
Union[DocumentType, Iterable[DocumentType]], **kwargs): + if self.add_negative_relations: + if isinstance(documents, DocumentType): + documents = [documents] + documents = self._add_negative_relations(documents) + + return super().encode(documents=documents, **kwargs) + + def encode_input( + self, + document: DocumentType, + is_training: bool = False, + ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]: + tokenizer_kwargs = dict( + padding=False, + truncation=True, + max_length=self.tokenizer.model_max_length, + return_offsets_mapping=False, + add_special_tokens=True, + ) + encoding = self.tokenizer(text=document.text, **tokenizer_kwargs) + encoding_pair = self.tokenizer(text=document.text_pair, **tokenizer_kwargs) + + task_encodings = [] + for coref_rel in document.binary_coref_relations: + start = encoding.char_to_token(coref_rel.head.start) + end = encoding.char_to_token(coref_rel.head.end - 1) + 1 + start_pair = encoding_pair.char_to_token(coref_rel.tail.start) + end_pair = encoding_pair.char_to_token(coref_rel.tail.end - 1) + 1 + if any(offset is None for offset in [start, end, start_pair, end_pair]): + logger.warning( + f"Could not get token offsets for arguments of coref relation: {coref_rel.resolve()}. Skip it." + ) + continue + task_encodings.append( + TaskEncoding( + document=document, + inputs={ + "encoding": encoding, + "encoding_pair": encoding_pair, + "start": start, + "end": end, + "start_pair": start_pair, + "end_pair": end_pair, + }, + metadata={"candidate_annotation": coref_rel}, + ) + ) + return task_encodings + + def encode_target( + self, + task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType], + ) -> Optional[TargetEncodingType]: + return task_encoding.metadata["candidate_annotation"].score + + def collate( + self, + task_encodings: Sequence[ + TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType] + ], + ) -> Tuple[ModelInputType, Optional[ModelTargetType]]: + inputs_dict = list_of_dicts2dict_of_lists( + [task_encoding.inputs for task_encoding in task_encodings] + ) + + inputs = { + k: self.tokenizer.pad(v, return_tensors="pt") + if k in ["encoding", "encoding_pair"] + else torch.tensor(v) + for k, v in inputs_dict.items() + } + + if not task_encodings[0].has_targets: + return inputs, None + targets = torch.tensor([task_encoding.targets for task_encoding in task_encodings]) + return inputs, targets + + def configure_model_metric(self, stage: str) -> MetricCollection: + return MetricCollection({"auroc": BinaryAUROC(thresholds=None)}) + + def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]: + raise NotImplementedError() + + def create_annotations_from_output( + self, + task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType], + task_output: TaskOutputType, + ) -> Iterator[Tuple[str, Annotation]]: + raise NotImplementedError() diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py new file mode 100644 index 000000000..1e51605ba --- /dev/null +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -0,0 +1,367 @@ +from typing import Any, Dict, Union + +import pytest +import torch.testing +from pytorch_ie.annotations import LabeledSpan +from torchmetrics import Metric, MetricCollection + +from pie_modules.document.types import ( + BinaryCorefRelation, + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, +) +from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule +from pie_modules.utils 
import flatten_dict, list_of_dicts2dict_of_lists
+from tests import _config_to_str
+
+TOKENIZER_NAME_OR_PATH = "bert-base-cased"
+
+CONFIGS = [
+    {},
+    # {"add_negative_relations": True},
+]
+CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS}
+
+
+@pytest.fixture(scope="module", params=CONFIGS_DICT.keys())
+def config(request):
+    return CONFIGS_DICT[request.param]
+
+
+@pytest.fixture(scope="module")
+def positive_documents():
+    doc1 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
+        id="0", text="Entity A works at B.", text_pair="And she founded C."
+    )
+    doc1.labeled_spans.append(LabeledSpan(start=0, end=8, label="PERSON"))
+    doc1.labeled_spans.append(LabeledSpan(start=18, end=19, label="COMPANY"))
+    doc1.labeled_spans_pair.append(LabeledSpan(start=4, end=7, label="PERSON"))
+    doc1.labeled_spans_pair.append(LabeledSpan(start=16, end=17, label="COMPANY"))
+    doc1.binary_coref_relations.append(
+        BinaryCorefRelation(head=doc1.labeled_spans[0], tail=doc1.labeled_spans_pair[0])
+    )
+
+    doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
+        id="1", text="Bob loves his cat.", text_pair="She sleeps a lot."
+    )
+    doc2.labeled_spans.append(LabeledSpan(start=0, end=3, label="PERSON"))
+    doc2.labeled_spans.append(LabeledSpan(start=10, end=17, label="ANIMAL"))
+    doc2.labeled_spans_pair.append(LabeledSpan(start=0, end=3, label="ANIMAL"))
+    doc2.binary_coref_relations.append(
+        BinaryCorefRelation(head=doc2.labeled_spans[1], tail=doc2.labeled_spans_pair[0])
+    )
+
+    return [doc1, doc2]
+
+
+def test_positive_documents(positive_documents):
+    assert len(positive_documents) == 2
+    doc1, doc2 = positive_documents
+    assert doc1.labeled_spans.resolve() == [("PERSON", "Entity A"), ("COMPANY", "B")]
+    assert doc1.labeled_spans_pair.resolve() == [("PERSON", "she"), ("COMPANY", "C")]
+    assert doc1.binary_coref_relations.resolve() == [
+        ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))
+    ]
+
+    assert doc2.labeled_spans.resolve() == [("PERSON", "Bob"), ("ANIMAL", "his cat")]
+    assert doc2.labeled_spans_pair.resolve() == [("ANIMAL", "She")]
+    assert doc2.binary_coref_relations.resolve() == [
+        ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She")))
+    ]
+
+
+@pytest.fixture(scope="module")
+def unprepared_taskmodule(config):
+    taskmodule = CrossTextBinaryCorefTaskModule(
+        tokenizer_name_or_path=TOKENIZER_NAME_OR_PATH, **config
+    )
+    assert not taskmodule.is_from_pretrained
+
+    return taskmodule
+
+
+@pytest.fixture(scope="module")
+def taskmodule(unprepared_taskmodule, positive_documents):
+    unprepared_taskmodule.prepare(positive_documents)
+    return unprepared_taskmodule
+
+
+@pytest.fixture(scope="module")
+def documents_with_negatives(taskmodule, positive_documents):
+    return list(taskmodule._add_negative_relations(positive_documents))
+
+
+def test_construct_negative_documents(positive_documents, documents_with_negatives):
+    assert len(positive_documents) == 2
+    TEXTS = [
+        "Entity A works at B.",
+        "And she founded C.",
+        "Bob loves his cat.",
+        "She sleeps a lot.",
+    ]
+    assert len(documents_with_negatives) == 12
+    all_scores = [
+        [coref_rel.score for coref_rel in doc.binary_coref_relations]
+        for doc in documents_with_negatives
+    ]
+    assert documents_with_negatives[0].text == TEXTS[1]
+    assert documents_with_negatives[0].text_pair == TEXTS[2]
+    assert documents_with_negatives[0].binary_coref_relations.resolve() == [
+        ("coref", (("PERSON", "she"), ("PERSON", "Bob"))),
+        ("coref", (("PERSON", "she"), ("ANIMAL", "his cat"))),
+        ("coref", (("COMPANY", "C"), ("PERSON", "Bob"))),
+ ("coref", (("COMPANY", "C"), ("ANIMAL", "his cat"))), + ] + assert all_scores[0] == [0.0, 0.0, 0.0, 0.0] + + assert documents_with_negatives[1].text == TEXTS[1] + assert documents_with_negatives[1].text_pair == TEXTS[0] + assert documents_with_negatives[1].binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "she"), ("PERSON", "Entity A"))), + ("coref", (("PERSON", "she"), ("COMPANY", "B"))), + ("coref", (("COMPANY", "C"), ("PERSON", "Entity A"))), + ("coref", (("COMPANY", "C"), ("COMPANY", "B"))), + ] + assert all_scores[1] == [1.0, 0.0, 0.0, 0.0] + + assert documents_with_negatives[2].text == TEXTS[1] + assert documents_with_negatives[2].text_pair == TEXTS[3] + assert documents_with_negatives[2].binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "she"), ("ANIMAL", "She"))), + ("coref", (("COMPANY", "C"), ("ANIMAL", "She"))), + ] + assert all_scores[2] == [0.0, 0.0] + + assert documents_with_negatives[3].text == TEXTS[2] + assert documents_with_negatives[3].text_pair == TEXTS[1] + assert documents_with_negatives[3].binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "Bob"), ("PERSON", "she"))), + ("coref", (("PERSON", "Bob"), ("COMPANY", "C"))), + ("coref", (("ANIMAL", "his cat"), ("PERSON", "she"))), + ("coref", (("ANIMAL", "his cat"), ("COMPANY", "C"))), + ] + assert all_scores[3] == [0.0, 0.0, 0.0, 0.0] + + assert documents_with_negatives[4].text == TEXTS[2] + assert documents_with_negatives[4].text_pair == TEXTS[0] + assert documents_with_negatives[4].binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A"))), + ("coref", (("PERSON", "Bob"), ("COMPANY", "B"))), + ("coref", (("ANIMAL", "his cat"), ("PERSON", "Entity A"))), + ("coref", (("ANIMAL", "his cat"), ("COMPANY", "B"))), + ] + assert all_scores[4] == [0.0, 0.0, 0.0, 0.0] + + assert documents_with_negatives[5].text == TEXTS[2] + assert documents_with_negatives[5].text_pair == TEXTS[3] + assert documents_with_negatives[5].binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "Bob"), ("ANIMAL", "She"))), + ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))), + ] + assert all_scores[5] == [0.0, 1.0] + + assert documents_with_negatives[6].text == TEXTS[0] + assert documents_with_negatives[6].text_pair == TEXTS[1] + assert documents_with_negatives[6].binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))), + ("coref", (("PERSON", "Entity A"), ("COMPANY", "C"))), + ("coref", (("COMPANY", "B"), ("PERSON", "she"))), + ("coref", (("COMPANY", "B"), ("COMPANY", "C"))), + ] + assert all_scores[6] == [1.0, 0.0, 0.0, 0.0] + + assert documents_with_negatives[7].text == TEXTS[0] + assert documents_with_negatives[7].text_pair == TEXTS[2] + assert documents_with_negatives[7].binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "Entity A"), ("PERSON", "Bob"))), + ("coref", (("PERSON", "Entity A"), ("ANIMAL", "his cat"))), + ("coref", (("COMPANY", "B"), ("PERSON", "Bob"))), + ("coref", (("COMPANY", "B"), ("ANIMAL", "his cat"))), + ] + assert all_scores[7] == [0.0, 0.0, 0.0, 0.0] + + assert documents_with_negatives[8].text == TEXTS[0] + assert documents_with_negatives[8].text_pair == TEXTS[3] + assert documents_with_negatives[8].binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "Entity A"), ("ANIMAL", "She"))), + ("coref", (("COMPANY", "B"), ("ANIMAL", "She"))), + ] + assert all_scores[8] == [0.0, 0.0] + + assert documents_with_negatives[9].text == TEXTS[3] + assert documents_with_negatives[9].text_pair == TEXTS[1] 
+ assert documents_with_negatives[9].binary_coref_relations.resolve() == [ + ("coref", (("ANIMAL", "She"), ("PERSON", "she"))), + ("coref", (("ANIMAL", "She"), ("COMPANY", "C"))), + ] + assert all_scores[9] == [0.0, 0.0] + + assert documents_with_negatives[10].text == TEXTS[3] + assert documents_with_negatives[10].text_pair == TEXTS[2] + assert documents_with_negatives[10].binary_coref_relations.resolve() == [ + ("coref", (("ANIMAL", "She"), ("PERSON", "Bob"))), + ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))), + ] + assert all_scores[10] == [0.0, 1.0] + + assert documents_with_negatives[11].text == TEXTS[3] + assert documents_with_negatives[11].text_pair == TEXTS[0] + assert documents_with_negatives[11].binary_coref_relations.resolve() == [ + ("coref", (("ANIMAL", "She"), ("PERSON", "Entity A"))), + ("coref", (("ANIMAL", "She"), ("COMPANY", "B"))), + ] + assert all_scores[11] == [0.0, 0.0] + + +@pytest.fixture(scope="module") +def task_encodings_without_target(taskmodule, documents_with_negatives): + task_encodings = taskmodule.encode_input(documents_with_negatives[0]) + return task_encodings + + +def test_encode_input(task_encodings_without_target, taskmodule): + task_encodings = task_encodings_without_target + convert_ids_to_tokens = taskmodule.tokenizer.convert_ids_to_tokens + + inputs_dict = list_of_dicts2dict_of_lists( + [task_encoding.inputs for task_encoding in task_encodings] + ) + tokens = [convert_ids_to_tokens(encoding["input_ids"]) for encoding in inputs_dict["encoding"]] + tokens_pair = [ + convert_ids_to_tokens(encoding["input_ids"]) for encoding in inputs_dict["encoding_pair"] + ] + assert tokens == [ + ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"], + ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"], + ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"], + ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"], + ] + assert tokens_pair == [ + ["[CLS]", "Bob", "loves", "his", "cat", ".", "[SEP]"], + ["[CLS]", "Bob", "loves", "his", "cat", ".", "[SEP]"], + ["[CLS]", "Bob", "loves", "his", "cat", ".", "[SEP]"], + ["[CLS]", "Bob", "loves", "his", "cat", ".", "[SEP]"], + ] + span_tokens = [ + toks[start:end] + for toks, start, end in zip(tokens, inputs_dict["start"], inputs_dict["end"]) + ] + span_tokens_pair = [ + toks[start:end] + for toks, start, end in zip( + tokens_pair, inputs_dict["start_pair"], inputs_dict["end_pair"] + ) + ] + assert span_tokens == [["she"], ["she"], ["C"], ["C"]] + assert span_tokens_pair == [["Bob"], ["his", "cat"], ["Bob"], ["his", "cat"]] + + +def test_encode_target(task_encodings_without_target, taskmodule): + target = taskmodule.encode_target(task_encodings_without_target[0]) + assert target == 0.0 + + +@pytest.fixture(scope="module", params=[False, True]) +def batch(taskmodule, positive_documents, documents_with_negatives, request): + if request.param: + original_value = taskmodule.add_negative_relations + taskmodule.add_negative_relations = True + task_encodings = taskmodule.encode(positive_documents, encode_target=True)[:4] + taskmodule.add_negative_relations = original_value + else: + task_encodings = taskmodule.encode(documents_with_negatives[0], encode_target=True) + result = taskmodule.collate(task_encodings) + return result + + +def test_collate(batch, taskmodule): + assert batch is not None + inputs, targets = batch + assert inputs is not None + assert set(inputs) == {"encoding", "encoding_pair", "start", "end", "start_pair", "end_pair"} + torch.testing.assert_close( + inputs["encoding"]["input_ids"], + torch.tensor( 
+ [ + [101, 1262, 1131, 1771, 140, 119, 102], + [101, 1262, 1131, 1771, 140, 119, 102], + [101, 1262, 1131, 1771, 140, 119, 102], + [101, 1262, 1131, 1771, 140, 119, 102], + ] + ), + ) + torch.testing.assert_close( + inputs["encoding"]["token_type_ids"], torch.zeros_like(inputs["encoding"]["input_ids"]) + ) + torch.testing.assert_close( + inputs["encoding"]["attention_mask"], torch.ones_like(inputs["encoding"]["input_ids"]) + ) + + torch.testing.assert_close( + inputs["encoding_pair"]["input_ids"], + torch.tensor( + [ + [101, 3162, 7871, 1117, 5855, 119, 102], + [101, 3162, 7871, 1117, 5855, 119, 102], + [101, 3162, 7871, 1117, 5855, 119, 102], + [101, 3162, 7871, 1117, 5855, 119, 102], + ] + ), + ) + torch.testing.assert_close( + inputs["encoding_pair"]["token_type_ids"], + torch.zeros_like(inputs["encoding_pair"]["input_ids"]), + ) + torch.testing.assert_close( + inputs["encoding_pair"]["attention_mask"], + torch.ones_like(inputs["encoding_pair"]["input_ids"]), + ) + + torch.testing.assert_close(inputs["start"], torch.tensor([2, 2, 4, 4])) + torch.testing.assert_close(inputs["end"], torch.tensor([3, 3, 5, 5])) + torch.testing.assert_close(inputs["start_pair"], torch.tensor([1, 3, 1, 3])) + torch.testing.assert_close(inputs["end_pair"], torch.tensor([2, 5, 2, 5])) + + torch.testing.assert_close(targets, torch.tensor([0.0, 0.0, 0.0, 0.0])) + + +def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]: + if isinstance(metric_or_collection, Metric): + return { + k: [vv.tolist() for vv in v] + for k, v in flatten_dict(metric_or_collection.metric_state).items() + } + elif isinstance(metric_or_collection, MetricCollection): + return flatten_dict({k: get_metric_state(v) for k, v in metric_or_collection.items()}) + else: + raise ValueError(f"unsupported type: {type(metric_or_collection)}") + + +def test_configure_metric(taskmodule, batch): + metric = taskmodule.configure_model_metric(stage="train") + + assert isinstance(metric, (Metric, MetricCollection)) + state = get_metric_state(metric) + assert state == {"auroc/preds": [], "auroc/target": []} + + # targets = batch[1] + targets = torch.tensor([0.0, 1.0, 0.0, 0.0]) + metric.update(targets, targets) + + state = get_metric_state(metric) + assert state == {"auroc/preds": [[0.0, 1.0, 0.0, 0.0]], "auroc/target": [[0.0, 1.0, 0.0, 0.0]]} + + assert metric.compute() == {"auroc": torch.tensor(1.0)} + + # torch.rand_like(targets) + random_targets = torch.tensor([0.2703, 0.6812, 0.2582, 0.8030]) + metric.update(random_targets, targets) + state = get_metric_state(metric) + assert state == { + "auroc/preds": [ + [0.0, 1.0, 0.0, 0.0], + [0.2703000009059906, 0.6812000274658203, 0.2581999897956848, 0.8029999732971191], + ], + "auroc/target": [[0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], + } + + assert metric.compute() == {"auroc": torch.tensor(0.9166666269302368)} From 52a8b45136f33c40bcfe9348bec5b612e4bd71fc Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 11 Sep 2024 22:50:53 +0200 Subject: [PATCH 02/49] call save_hyperparameters() --- src/pie_modules/taskmodules/cross_text_binary_coref.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index b045ee096..34b9268e7 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -70,6 +70,7 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) + self.save_hyperparameters() 
self.add_negative_relations = add_negative_relations From 71fb6e1cfe6118b9a1f83094a83f2ff5d5fd8c4a Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 11 Sep 2024 23:53:24 +0200 Subject: [PATCH 03/49] make taskmodule (future) model compliant --- .../taskmodules/cross_text_binary_coref.py | 33 +++++--- .../test_cross_text_binary_coref.py | 75 +++++++++++++------ 2 files changed, 76 insertions(+), 32 deletions(-) diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index 34b9268e7..968be53bf 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -16,7 +16,7 @@ from pytorch_ie import Annotation from pytorch_ie.core import TaskEncoding, TaskModule from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize -from torchmetrics import MetricCollection +from torchmetrics import Metric, MetricCollection from torchmetrics.classification import BinaryAUROC from transformers import AutoTokenizer from typing_extensions import TypeAlias @@ -25,6 +25,7 @@ BinaryCorefRelation, TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, ) +from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction from pie_modules.utils import list_of_dicts2dict_of_lists logger = logging.getLogger(__name__) @@ -45,8 +46,8 @@ class TaskOutputType(TypedDict, total=False): ModelInputType: TypeAlias = Dict[str, torch.Tensor] -ModelTargetType: TypeAlias = torch.Tensor -ModelOutputType: TypeAlias = torch.Tensor +ModelTargetType: TypeAlias = Dict[str, torch.Tensor] +ModelOutputType: TypeAlias = Dict[str, torch.Tensor] TaskModuleType: TypeAlias = TaskModule[ # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput @@ -59,6 +60,10 @@ class TaskOutputType(TypedDict, total=False): ] +def _get_labels(model_output: ModelTargetType) -> torch.Tensor: + return model_output["labels"] + + @TaskModule.register() class CrossTextBinaryCorefTaskModule(TaskModuleType, ChangesTokenizerVocabSize): DOCUMENT_TYPE = DocumentType @@ -157,10 +162,10 @@ def encode_input( inputs={ "encoding": encoding, "encoding_pair": encoding_pair, - "start": start, - "end": end, - "start_pair": start_pair, - "end_pair": end_pair, + "pooler_start_indices": start, + "pooler_end_indices": end, + "pooler_start_indices_pair": start_pair, + "pooler_end_indices_pair": end_pair, }, metadata={"candidate_annotation": coref_rel}, ) @@ -189,14 +194,22 @@ def collate( else torch.tensor(v) for k, v in inputs_dict.items() } + for k, v in inputs.items(): + if k.startswith("pooler_start_indices") or k.startswith("pooler_end_indices"): + inputs[k] = v.unsqueeze(-1) if not task_encodings[0].has_targets: return inputs, None - targets = torch.tensor([task_encoding.targets for task_encoding in task_encodings]) + targets = { + "labels": torch.tensor([task_encoding.targets for task_encoding in task_encodings]) + } return inputs, targets - def configure_model_metric(self, stage: str) -> MetricCollection: - return MetricCollection({"auroc": BinaryAUROC(thresholds=None)}) + def configure_model_metric(self, stage: str) -> Metric: + return WrappedMetricWithPrepareFunction( + metric=MetricCollection({"auroc": BinaryAUROC(thresholds=None)}), + prepare_function=_get_labels, + ) def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]: raise NotImplementedError() diff --git a/tests/taskmodules/test_cross_text_binary_coref.py 
b/tests/taskmodules/test_cross_text_binary_coref.py index 1e51605ba..7415261a1 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -243,12 +243,16 @@ def test_encode_input(task_encodings_without_target, taskmodule): ] span_tokens = [ toks[start:end] - for toks, start, end in zip(tokens, inputs_dict["start"], inputs_dict["end"]) + for toks, start, end in zip( + tokens, inputs_dict["pooler_start_indices"], inputs_dict["pooler_end_indices"] + ) ] span_tokens_pair = [ toks[start:end] for toks, start, end in zip( - tokens_pair, inputs_dict["start_pair"], inputs_dict["end_pair"] + tokens_pair, + inputs_dict["pooler_start_indices_pair"], + inputs_dict["pooler_end_indices_pair"], ) ] assert span_tokens == [["she"], ["she"], ["C"], ["C"]] @@ -277,7 +281,14 @@ def test_collate(batch, taskmodule): assert batch is not None inputs, targets = batch assert inputs is not None - assert set(inputs) == {"encoding", "encoding_pair", "start", "end", "start_pair", "end_pair"} + assert set(inputs) == { + "pooler_end_indices", + "encoding_pair", + "pooler_end_indices_pair", + "pooler_start_indices", + "encoding", + "pooler_start_indices_pair", + } torch.testing.assert_close( inputs["encoding"]["input_ids"], torch.tensor( @@ -316,20 +327,21 @@ def test_collate(batch, taskmodule): torch.ones_like(inputs["encoding_pair"]["input_ids"]), ) - torch.testing.assert_close(inputs["start"], torch.tensor([2, 2, 4, 4])) - torch.testing.assert_close(inputs["end"], torch.tensor([3, 3, 5, 5])) - torch.testing.assert_close(inputs["start_pair"], torch.tensor([1, 3, 1, 3])) - torch.testing.assert_close(inputs["end_pair"], torch.tensor([2, 5, 2, 5])) + torch.testing.assert_close(inputs["pooler_start_indices"], torch.tensor([[2], [2], [4], [4]])) + torch.testing.assert_close(inputs["pooler_end_indices"], torch.tensor([[3], [3], [5], [5]])) + torch.testing.assert_close( + inputs["pooler_start_indices_pair"], torch.tensor([[1], [3], [1], [3]]) + ) + torch.testing.assert_close( + inputs["pooler_end_indices_pair"], torch.tensor([[2], [5], [2], [5]]) + ) - torch.testing.assert_close(targets, torch.tensor([0.0, 0.0, 0.0, 0.0])) + torch.testing.assert_close(targets, {"labels": torch.tensor([0.0, 0.0, 0.0, 0.0])}) def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]: if isinstance(metric_or_collection, Metric): - return { - k: [vv.tolist() for vv in v] - for k, v in flatten_dict(metric_or_collection.metric_state).items() - } + return flatten_dict(metric_or_collection.metric_state) elif isinstance(metric_or_collection, MetricCollection): return flatten_dict({k: get_metric_state(v) for k, v in metric_or_collection.items()}) else: @@ -344,24 +356,43 @@ def test_configure_metric(taskmodule, batch): assert state == {"auroc/preds": [], "auroc/target": []} # targets = batch[1] - targets = torch.tensor([0.0, 1.0, 0.0, 0.0]) + targets = {"labels": torch.tensor([0.0, 1.0, 0.0, 0.0])} metric.update(targets, targets) state = get_metric_state(metric) - assert state == {"auroc/preds": [[0.0, 1.0, 0.0, 0.0]], "auroc/target": [[0.0, 1.0, 0.0, 0.0]]} + torch.testing.assert_close( + state, + { + "auroc/preds": [torch.tensor([0.0, 1.0, 0.0, 0.0])], + "auroc/target": [torch.tensor([0.0, 1.0, 0.0, 0.0])], + }, + ) assert metric.compute() == {"auroc": torch.tensor(1.0)} # torch.rand_like(targets) - random_targets = torch.tensor([0.2703, 0.6812, 0.2582, 0.8030]) + random_targets = {"labels": torch.tensor([0.2703, 0.6812, 0.2582, 0.8030])} 
metric.update(random_targets, targets)
     state = get_metric_state(metric)
-    assert state == {
-        "auroc/preds": [
-            [0.0, 1.0, 0.0, 0.0],
-            [0.2703000009059906, 0.6812000274658203, 0.2581999897956848, 0.8029999732971191],
-        ],
-        "auroc/target": [[0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]],
-    }
+    torch.testing.assert_close(
+        state,
+        {
+            "auroc/preds": [
+                torch.tensor([0.0, 1.0, 0.0, 0.0]),
+                torch.tensor(
+                    [
+                        0.2703000009059906,
+                        0.6812000274658203,
+                        0.2581999897956848,
+                        0.8029999732971191,
+                    ]
+                ),
+            ],
+            "auroc/target": [
+                torch.tensor([0.0, 1.0, 0.0, 0.0]),
+                torch.tensor([0.0, 1.0, 0.0, 0.0]),
+            ],
+        },
+    )
 
     assert metric.compute() == {"auroc": torch.tensor(0.9166666269302368)}

From abec89694234b209091c039cd85cac699437189e Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Thu, 12 Sep 2024 00:06:07 +0200
Subject: [PATCH 04/49] implement SimpleSimilarityModel

---
 src/pie_modules/models/simple_similarity.py | 210 ++++++++++++++++++++
 1 file changed, 210 insertions(+)
 create mode 100644 src/pie_modules/models/simple_similarity.py

diff --git a/src/pie_modules/models/simple_similarity.py b/src/pie_modules/models/simple_similarity.py
new file mode 100644
index 000000000..6d3766f2c
--- /dev/null
+++ b/src/pie_modules/models/simple_similarity.py
@@ -0,0 +1,210 @@
+import logging
+from typing import Any, Dict, Iterator, MutableMapping, Optional, Tuple, Union
+
+import torch
+from pytorch_ie.core import PyTorchIEModel
+from pytorch_ie.models.interface import RequiresModelNameOrPath
+from torch import FloatTensor, LongTensor, nn
+from torch.nn import BCELoss, Parameter
+from torch.optim import AdamW
+from transformers import AutoConfig, AutoModel, get_linear_schedule_with_warmup
+from transformers.modeling_outputs import SequenceClassifierOutput
+from typing_extensions import TypeAlias
+
+from .common import ModelWithBoilerplate
+from .components.pooler import get_pooler_and_output_size
+
+# model inputs / outputs / targets
+InputType: TypeAlias = MutableMapping[str, LongTensor]
+OutputType: TypeAlias = SequenceClassifierOutput
+TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]]
+# step inputs (batch) / outputs (loss)
+StepInputType: TypeAlias = Tuple[InputType, TargetType]
+StepOutputType: TypeAlias = FloatTensor
+
+
+HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE = {
+    "albert": "classifier_dropout_prob",
+    "distilbert": "seq_classif_dropout",
+}
+
+logger = logging.getLogger(__name__)
+
+
+@PyTorchIEModel.register()
+class SimpleSimilarityModel(
+    ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType],
+    RequiresModelNameOrPath,
+):
+    """A model for binary similarity classification of span pairs across two texts, e.g. for
+    cross-text coreference resolution. Both texts are encoded with the same (optionally frozen)
+    HuggingFace base model, the argument spans are pooled into fixed-size embeddings, and the
+    cosine similarity between the two pooled embeddings is returned as the prediction score.
+
+    Args:
+        model_name_or_path: The name or path of the HuggingFace model to use.
+        tokenizer_vocab_size: The size of the tokenizer vocabulary. If provided, the model's
+            token embeddings are resized to this size.
+        classifier_dropout: The dropout probability for the classifier. If not provided, the
+            dropout probability is taken from the Huggingface model config.
+        learning_rate: The learning rate for the optimizer.
+        task_learning_rate: The learning rate for the task-specific parameters. If None, the
+            learning rate for all parameters is set to `learning_rate`.
+        warmup_proportion: The proportion of steps to warm up the learning rate.
+        pooler: The pooler configuration. If None, CLS token pooling is used.
+        freeze_base_model: If True, the base model parameters are frozen.
+        hidden_dim: If provided, the pooled span embeddings are projected to this size with a
+            linear layer before the cosine similarity is computed.
+        **kwargs: Additional keyword arguments passed to the parent class,
+            see :class:`ModelWithBoilerplate`.
+    """
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        tokenizer_vocab_size: Optional[int] = None,
+        classifier_dropout: Optional[float] = None,
+        learning_rate: float = 1e-5,
+        task_learning_rate: Optional[float] = None,
+        warmup_proportion: float = 0.1,
+        # TODO: use "mention_pooling" by default?
+        pooler: Optional[Union[Dict[str, Any], str]] = None,
+        freeze_base_model: bool = False,
+        hidden_dim: Optional[int] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.save_hyperparameters()
+
+        self.learning_rate = learning_rate
+        self.task_learning_rate = task_learning_rate
+        self.warmup_proportion = warmup_proportion
+        self.freeze_base_model = freeze_base_model
+
+        config = AutoConfig.from_pretrained(model_name_or_path)
+        if self.is_from_pretrained:
+            self.model = AutoModel.from_config(config=config)
+        else:
+            self.model = AutoModel.from_pretrained(model_name_or_path, config=config)
+
+        if tokenizer_vocab_size is not None:
+            self.model.resize_token_embeddings(tokenizer_vocab_size)
+
+        if self.freeze_base_model:
+            for param in self.model.parameters():
+                param.requires_grad = False
+
+        if classifier_dropout is None:
+            # Get the classifier dropout value from the Huggingface model config.
+            # This is a bit of a mess since some Configs use different variable names or change the semantics
+            # of the dropout (e.g. DistilBert has one dropout prob for QA and one for Seq classification, and a
+            # general one for embeddings, encoder and pooler).
+            classifier_dropout_attr = HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE.get(
+                config.model_type, "classifier_dropout"
+            )
+            classifier_dropout = getattr(config, classifier_dropout_attr) or 0.0
+        self.dropout = nn.Dropout(classifier_dropout)
+
+        if isinstance(pooler, str):
+            pooler = {"type": pooler}
+        self.pooler_config = pooler or {}
+        self.pooler, pooler_output_dim = get_pooler_and_output_size(
+            config=self.pooler_config,
+            input_dim=config.hidden_size,
+        )
+        if hidden_dim is not None:
+            self.classifier = nn.Linear(pooler_output_dim, hidden_dim)
+        else:
+            self.classifier = None
+
+        # TODO: is this ok?
+ self.loss_fct = BCELoss() + + def get_pooled_output(self, model_inputs, pooler_inputs): + output = self.model(**model_inputs) + hidden_state = output.last_hidden_state + pooled_output = self.pooler(hidden_state, **pooler_inputs) + pooled_output = self.dropout(pooled_output) + if self.classifier is not None: + return self.classifier(pooled_output) + return pooled_output + + def forward( + self, + inputs: InputType, + targets: Optional[TargetType] = None, + return_hidden_states: bool = False, + ) -> OutputType: + model_inputs = None + model_inputs_pair = None + pooler_inputs = {} + pooler_inputs_pair = {} + for k, v in inputs.items(): + if k.startswith("pooler_") and k.endswith("_pair"): + k_target = k[len("pooler_") : -len("_pair")] + pooler_inputs_pair[k_target] = v + elif k.startswith("pooler_"): + k_target = k[len("pooler_") :] + pooler_inputs[k_target] = v + elif k == "encoding": + model_inputs = v + elif k == "encoding_pair": + model_inputs_pair = v + else: + raise ValueError(f"unexpected model input: {k}") + + pooled_output = self.get_pooled_output(model_inputs, pooler_inputs) + pooled_output_pair = self.get_pooled_output(model_inputs_pair, pooler_inputs_pair) + + logits = torch.nn.functional.cosine_similarity(pooled_output, pooled_output_pair) + + result = {"logits": logits} + if targets is not None: + labels = targets["labels"] + loss = self.loss_fct(logits, labels) + result["loss"] = loss + if return_hidden_states: + raise NotImplementedError("return_hidden_states is not yet implemented") + + return SequenceClassifierOutput(**result) + + def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: + labels = (outputs.logits > 0.5).to(torch.long) + + return {"labels": labels, "probabilities": outputs.logits} + + def base_model_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: + if prefix: + prefix = f"{prefix}." + return self.model.named_parameters(prefix=f"{prefix}model") + + def task_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: + if prefix: + prefix = f"{prefix}." 
+ base_model_parameter_names = dict(self.base_model_named_parameters(prefix=prefix)).keys() + for name, param in self.named_parameters(prefix=prefix): + if name not in base_model_parameter_names: + yield name, param + + def configure_optimizers(self): + if self.task_learning_rate is not None: + base_model_params = (param for name, param in self.base_model_named_parameters()) + task_params = (param for name, param in self.task_named_parameters()) + optimizer = AdamW( + [ + {"params": base_model_params, "lr": self.learning_rate}, + {"params": task_params, "lr": self.task_learning_rate}, + ] + ) + else: + optimizer = AdamW(self.parameters(), lr=self.learning_rate) + + if self.warmup_proportion > 0.0: + stepping_batches = self.trainer.estimated_stepping_batches + scheduler = get_linear_schedule_with_warmup( + optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches + ) + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] + else: + return optimizer From 16caf818192c5e22aabee62db37d1f8ef6150826 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 13:36:17 +0200 Subject: [PATCH 05/49] use fixture data for documents_with_negatives --- .../documents_with_negatives.json | 800 ++++++++++++++++++ .../test_cross_text_binary_coref.py | 108 ++- 2 files changed, 863 insertions(+), 45 deletions(-) create mode 100644 tests/fixtures/taskmodules/cross_text_binary_coref/documents_with_negatives.json diff --git a/tests/fixtures/taskmodules/cross_text_binary_coref/documents_with_negatives.json b/tests/fixtures/taskmodules/cross_text_binary_coref/documents_with_negatives.json new file mode 100644 index 000000000..0e0b4062f --- /dev/null +++ b/tests/fixtures/taskmodules/cross_text_binary_coref/documents_with_negatives.json @@ -0,0 +1,800 @@ +[ + { + "text_pair": "Bob loves his cat.", + "text": "And she founded C.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": -5246751469876588720 + }, + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": 3043206444225553475 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": 5373078146820384347 + }, + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -3679976720952382748 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -5246751469876588720, + "tail": -3679976720952382748, + "label": "coref", + "score": 0.0, + "_id": -1226852003320818417 + }, + { + "head": -5246751469876588720, + "tail": 5373078146820384347, + "label": "coref", + "score": 0.0, + "_id": -2897381892745677680 + }, + { + "head": 3043206444225553475, + "tail": -3679976720952382748, + "label": "coref", + "score": 0.0, + "_id": 4747715004687052922 + }, + { + "head": 3043206444225553475, + "tail": 5373078146820384347, + "label": "coref", + "score": 0.0, + "_id": 8355440541443623552 + } + ], + "predictions": [] + } + }, + { + "text_pair": "Entity A works at B.", + "text": "And she founded C.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": -5246751469876588720 + }, + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": 3043206444225553475 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 8, + 
"label": "PERSON", + "score": 1.0, + "_id": 3233654095506762724 + }, + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": -2183238448703307780 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -5246751469876588720, + "tail": 3233654095506762724, + "label": "coref", + "score": 1.0, + "_id": -4357456139038854264 + }, + { + "head": -5246751469876588720, + "tail": -2183238448703307780, + "label": "coref", + "score": 0.0, + "_id": 466813473723110234 + }, + { + "head": 3043206444225553475, + "tail": 3233654095506762724, + "label": "coref", + "score": 0.0, + "_id": -4272399218893089512 + }, + { + "head": 3043206444225553475, + "tail": -2183238448703307780, + "label": "coref", + "score": 0.0, + "_id": -5602326156476594098 + } + ], + "predictions": [] + } + }, + { + "text_pair": "She sleeps a lot.", + "text": "And she founded C.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": -5246751469876588720 + }, + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": 3043206444225553475 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": -190677143789164847 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -5246751469876588720, + "tail": -190677143789164847, + "label": "coref", + "score": 0.0, + "_id": -8510476152511400278 + }, + { + "head": 3043206444225553475, + "tail": -190677143789164847, + "label": "coref", + "score": 0.0, + "_id": 3867128290998447006 + } + ], + "predictions": [] + } + }, + { + "text_pair": "And she founded C.", + "text": "Bob loves his cat.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": 5373078146820384347 + }, + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -3679976720952382748 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": -5246751469876588720 + }, + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": 3043206444225553475 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -3679976720952382748, + "tail": -5246751469876588720, + "label": "coref", + "score": 0.0, + "_id": -2767191573101294319 + }, + { + "head": -3679976720952382748, + "tail": 3043206444225553475, + "label": "coref", + "score": 0.0, + "_id": -4437612686117351921 + }, + { + "head": 5373078146820384347, + "tail": -5246751469876588720, + "label": "coref", + "score": 0.0, + "_id": 5020739476238125539 + }, + { + "head": 5373078146820384347, + "tail": 3043206444225553475, + "label": "coref", + "score": 0.0, + "_id": 3902122025380513665 + } + ], + "predictions": [] + } + }, + { + "text_pair": "Entity A works at B.", + "text": "Bob loves his cat.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": 5373078146820384347 + }, + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -3679976720952382748 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": 
3233654095506762724 + }, + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": -2183238448703307780 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -3679976720952382748, + "tail": 3233654095506762724, + "label": "coref", + "score": 0.0, + "_id": -8901055438221583123 + }, + { + "head": -3679976720952382748, + "tail": -2183238448703307780, + "label": "coref", + "score": 0.0, + "_id": -8898764560981633135 + }, + { + "head": 5373078146820384347, + "tail": 3233654095506762724, + "label": "coref", + "score": 0.0, + "_id": -1347933737476127508 + }, + { + "head": 5373078146820384347, + "tail": -2183238448703307780, + "label": "coref", + "score": 0.0, + "_id": 3930515724475035731 + } + ], + "predictions": [] + } + }, + { + "text_pair": "She sleeps a lot.", + "text": "Bob loves his cat.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": 5373078146820384347 + }, + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -3679976720952382748 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": -190677143789164847 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -3679976720952382748, + "tail": -190677143789164847, + "label": "coref", + "score": 0.0, + "_id": -5159885311314414733 + }, + { + "head": 5373078146820384347, + "tail": -190677143789164847, + "label": "coref", + "score": 1.0, + "_id": -4858368627143918533 + } + ], + "predictions": [] + } + }, + { + "text_pair": "And she founded C.", + "text": "Entity A works at B.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": 3233654095506762724 + }, + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": -2183238448703307780 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": -5246751469876588720 + }, + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": 3043206444225553475 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": 3233654095506762724, + "tail": -5246751469876588720, + "label": "coref", + "score": 1.0, + "_id": 2444090963512005184 + }, + { + "head": 3233654095506762724, + "tail": 3043206444225553475, + "label": "coref", + "score": 0.0, + "_id": -7963340116969175614 + }, + { + "head": -2183238448703307780, + "tail": -5246751469876588720, + "label": "coref", + "score": 0.0, + "_id": -9120191367688252721 + }, + { + "head": -2183238448703307780, + "tail": 3043206444225553475, + "label": "coref", + "score": 0.0, + "_id": 7975222748039939420 + } + ], + "predictions": [] + } + }, + { + "text_pair": "Bob loves his cat.", + "text": "Entity A works at B.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": 3233654095506762724 + }, + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": -2183238448703307780 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": 5373078146820384347 + }, + { + "start": 0, + 
"end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -3679976720952382748 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": 3233654095506762724, + "tail": -3679976720952382748, + "label": "coref", + "score": 0.0, + "_id": 1280608060947850168 + }, + { + "head": 3233654095506762724, + "tail": 5373078146820384347, + "label": "coref", + "score": 0.0, + "_id": -3000515518015844819 + }, + { + "head": -2183238448703307780, + "tail": -3679976720952382748, + "label": "coref", + "score": 0.0, + "_id": -4464070305304755517 + }, + { + "head": -2183238448703307780, + "tail": 5373078146820384347, + "label": "coref", + "score": 0.0, + "_id": 3298512753939125167 + } + ], + "predictions": [] + } + }, + { + "text_pair": "She sleeps a lot.", + "text": "Entity A works at B.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": 3233654095506762724 + }, + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": -2183238448703307780 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": -190677143789164847 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": 3233654095506762724, + "tail": -190677143789164847, + "label": "coref", + "score": 0.0, + "_id": -3444435532096461506 + }, + { + "head": -2183238448703307780, + "tail": -190677143789164847, + "label": "coref", + "score": 0.0, + "_id": -3912955313637853940 + } + ], + "predictions": [] + } + }, + { + "text_pair": "And she founded C.", + "text": "She sleeps a lot.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": -190677143789164847 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": -5246751469876588720 + }, + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": 3043206444225553475 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -190677143789164847, + "tail": -5246751469876588720, + "label": "coref", + "score": 0.0, + "_id": -6992824161873864749 + }, + { + "head": -190677143789164847, + "tail": 3043206444225553475, + "label": "coref", + "score": 0.0, + "_id": 6180444938490764939 + } + ], + "predictions": [] + } + }, + { + "text_pair": "Bob loves his cat.", + "text": "She sleeps a lot.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": -190677143789164847 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": 5373078146820384347 + }, + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -3679976720952382748 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -190677143789164847, + "tail": -3679976720952382748, + "label": "coref", + "score": 0.0, + "_id": 2061654283494000583 + }, + { + "head": -190677143789164847, + "tail": 5373078146820384347, + "label": "coref", + "score": 1.0, + "_id": -4650461605955518398 + } + ], + "predictions": [] + } + }, + { + "text_pair": "Entity A works at B.", + "text": "She sleeps a lot.", 
+ "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": -190677143789164847 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": 3233654095506762724 + }, + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": -2183238448703307780 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -190677143789164847, + "tail": 3233654095506762724, + "label": "coref", + "score": 0.0, + "_id": 8092666078797453961 + }, + { + "head": -190677143789164847, + "tail": -2183238448703307780, + "label": "coref", + "score": 0.0, + "_id": -5075628532960934416 + } + ], + "predictions": [] + } + } +] diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 7415261a1..7a62e43fc 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -1,3 +1,4 @@ +import json from typing import Any, Dict, Union import pytest @@ -11,7 +12,7 @@ ) from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule from pie_modules.utils import flatten_dict, list_of_dicts2dict_of_lists -from tests import _config_to_str +from tests import FIXTURES_ROOT, _config_to_str TOKENIZER_NAME_OR_PATH = "bert-base-cased" @@ -85,27 +86,23 @@ def taskmodule(unprepared_taskmodule, positive_documents): return unprepared_taskmodule -@pytest.fixture(scope="module") -def documents_with_negatives(taskmodule, positive_documents): - return list(taskmodule._add_negative_relations(positive_documents)) - - -def test_construct_negative_documents(positive_documents, documents_with_negatives): +def test_construct_negative_documents(taskmodule, positive_documents): assert len(positive_documents) == 2 + docs = list(taskmodule._add_negative_relations(positive_documents)) TEXTS = [ "Entity A works at B.", "And she founded C.", "Bob loves his cat.", "She sleeps a lot.", ] - assert len(documents_with_negatives) == 12 + assert len(docs) == 12 all_scores = [ [coref_rel.score for coref_rel in doc.binary_coref_relations] - for doc in documents_with_negatives + for doc in docs ] - assert documents_with_negatives[0].text == TEXTS[1] - assert documents_with_negatives[0].text_pair == TEXTS[2] - assert documents_with_negatives[0].binary_coref_relations.resolve() == [ + assert docs[0].text == TEXTS[1] + assert docs[0].text_pair == TEXTS[2] + assert docs[0].binary_coref_relations.resolve() == [ ("coref", (("PERSON", "she"), ("PERSON", "Bob"))), ("coref", (("PERSON", "she"), ("ANIMAL", "his cat"))), ("coref", (("COMPANY", "C"), ("PERSON", "Bob"))), @@ -113,9 +110,9 @@ def test_construct_negative_documents(positive_documents, documents_with_negativ ] assert all_scores[0] == [0.0, 0.0, 0.0, 0.0] - assert documents_with_negatives[1].text == TEXTS[1] - assert documents_with_negatives[1].text_pair == TEXTS[0] - assert documents_with_negatives[1].binary_coref_relations.resolve() == [ + assert docs[1].text == TEXTS[1] + assert docs[1].text_pair == TEXTS[0] + assert docs[1].binary_coref_relations.resolve() == [ ("coref", (("PERSON", "she"), ("PERSON", "Entity A"))), ("coref", (("PERSON", "she"), ("COMPANY", "B"))), ("coref", (("COMPANY", "C"), ("PERSON", "Entity A"))), @@ -123,17 +120,17 @@ def test_construct_negative_documents(positive_documents, documents_with_negativ ] assert all_scores[1] == [1.0, 0.0, 
0.0, 0.0] - assert documents_with_negatives[2].text == TEXTS[1] - assert documents_with_negatives[2].text_pair == TEXTS[3] - assert documents_with_negatives[2].binary_coref_relations.resolve() == [ + assert docs[2].text == TEXTS[1] + assert docs[2].text_pair == TEXTS[3] + assert docs[2].binary_coref_relations.resolve() == [ ("coref", (("PERSON", "she"), ("ANIMAL", "She"))), ("coref", (("COMPANY", "C"), ("ANIMAL", "She"))), ] assert all_scores[2] == [0.0, 0.0] - assert documents_with_negatives[3].text == TEXTS[2] - assert documents_with_negatives[3].text_pair == TEXTS[1] - assert documents_with_negatives[3].binary_coref_relations.resolve() == [ + assert docs[3].text == TEXTS[2] + assert docs[3].text_pair == TEXTS[1] + assert docs[3].binary_coref_relations.resolve() == [ ("coref", (("PERSON", "Bob"), ("PERSON", "she"))), ("coref", (("PERSON", "Bob"), ("COMPANY", "C"))), ("coref", (("ANIMAL", "his cat"), ("PERSON", "she"))), @@ -141,9 +138,9 @@ def test_construct_negative_documents(positive_documents, documents_with_negativ ] assert all_scores[3] == [0.0, 0.0, 0.0, 0.0] - assert documents_with_negatives[4].text == TEXTS[2] - assert documents_with_negatives[4].text_pair == TEXTS[0] - assert documents_with_negatives[4].binary_coref_relations.resolve() == [ + assert docs[4].text == TEXTS[2] + assert docs[4].text_pair == TEXTS[0] + assert docs[4].binary_coref_relations.resolve() == [ ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A"))), ("coref", (("PERSON", "Bob"), ("COMPANY", "B"))), ("coref", (("ANIMAL", "his cat"), ("PERSON", "Entity A"))), @@ -151,17 +148,17 @@ def test_construct_negative_documents(positive_documents, documents_with_negativ ] assert all_scores[4] == [0.0, 0.0, 0.0, 0.0] - assert documents_with_negatives[5].text == TEXTS[2] - assert documents_with_negatives[5].text_pair == TEXTS[3] - assert documents_with_negatives[5].binary_coref_relations.resolve() == [ + assert docs[5].text == TEXTS[2] + assert docs[5].text_pair == TEXTS[3] + assert docs[5].binary_coref_relations.resolve() == [ ("coref", (("PERSON", "Bob"), ("ANIMAL", "She"))), ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))), ] assert all_scores[5] == [0.0, 1.0] - assert documents_with_negatives[6].text == TEXTS[0] - assert documents_with_negatives[6].text_pair == TEXTS[1] - assert documents_with_negatives[6].binary_coref_relations.resolve() == [ + assert docs[6].text == TEXTS[0] + assert docs[6].text_pair == TEXTS[1] + assert docs[6].binary_coref_relations.resolve() == [ ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))), ("coref", (("PERSON", "Entity A"), ("COMPANY", "C"))), ("coref", (("COMPANY", "B"), ("PERSON", "she"))), @@ -169,9 +166,9 @@ def test_construct_negative_documents(positive_documents, documents_with_negativ ] assert all_scores[6] == [1.0, 0.0, 0.0, 0.0] - assert documents_with_negatives[7].text == TEXTS[0] - assert documents_with_negatives[7].text_pair == TEXTS[2] - assert documents_with_negatives[7].binary_coref_relations.resolve() == [ + assert docs[7].text == TEXTS[0] + assert docs[7].text_pair == TEXTS[2] + assert docs[7].binary_coref_relations.resolve() == [ ("coref", (("PERSON", "Entity A"), ("PERSON", "Bob"))), ("coref", (("PERSON", "Entity A"), ("ANIMAL", "his cat"))), ("coref", (("COMPANY", "B"), ("PERSON", "Bob"))), @@ -179,39 +176,60 @@ def test_construct_negative_documents(positive_documents, documents_with_negativ ] assert all_scores[7] == [0.0, 0.0, 0.0, 0.0] - assert documents_with_negatives[8].text == TEXTS[0] - assert documents_with_negatives[8].text_pair == TEXTS[3] - assert 
documents_with_negatives[8].binary_coref_relations.resolve() == [ + assert docs[8].text == TEXTS[0] + assert docs[8].text_pair == TEXTS[3] + assert docs[8].binary_coref_relations.resolve() == [ ("coref", (("PERSON", "Entity A"), ("ANIMAL", "She"))), ("coref", (("COMPANY", "B"), ("ANIMAL", "She"))), ] assert all_scores[8] == [0.0, 0.0] - assert documents_with_negatives[9].text == TEXTS[3] - assert documents_with_negatives[9].text_pair == TEXTS[1] - assert documents_with_negatives[9].binary_coref_relations.resolve() == [ + assert docs[9].text == TEXTS[3] + assert docs[9].text_pair == TEXTS[1] + assert docs[9].binary_coref_relations.resolve() == [ ("coref", (("ANIMAL", "She"), ("PERSON", "she"))), ("coref", (("ANIMAL", "She"), ("COMPANY", "C"))), ] assert all_scores[9] == [0.0, 0.0] - assert documents_with_negatives[10].text == TEXTS[3] - assert documents_with_negatives[10].text_pair == TEXTS[2] - assert documents_with_negatives[10].binary_coref_relations.resolve() == [ + assert docs[10].text == TEXTS[3] + assert docs[10].text_pair == TEXTS[2] + assert docs[10].binary_coref_relations.resolve() == [ ("coref", (("ANIMAL", "She"), ("PERSON", "Bob"))), ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))), ] assert all_scores[10] == [0.0, 1.0] - assert documents_with_negatives[11].text == TEXTS[3] - assert documents_with_negatives[11].text_pair == TEXTS[0] - assert documents_with_negatives[11].binary_coref_relations.resolve() == [ + assert docs[11].text == TEXTS[3] + assert docs[11].text_pair == TEXTS[0] + assert docs[11].binary_coref_relations.resolve() == [ ("coref", (("ANIMAL", "She"), ("PERSON", "Entity A"))), ("coref", (("ANIMAL", "She"), ("COMPANY", "B"))), ] assert all_scores[11] == [0.0, 0.0] +@pytest.fixture(scope="module") +def documents_with_negatives(taskmodule, positive_documents): + file_name = ( + FIXTURES_ROOT / "taskmodules" / "cross_text_binary_coref" / "documents_with_negatives.json" + ) + + # to regenerate the fixture file, uncomment the following four lines: + # result = list(taskmodule._add_negative_relations(positive_documents)) + # result_json = [doc.asdict() for doc in result] + # with open(file_name, "w") as f: + # json.dump(result_json, f, indent=2) + + with open(file_name) as f: + result_json = json.load(f) + result = [ + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations.fromdict(doc_json) + for doc_json in result_json + ] + + return result + + @pytest.fixture(scope="module") def task_encodings_without_target(taskmodule, documents_with_negatives): task_encodings = taskmodule.encode_input(documents_with_negatives[0]) From b2f739b88b803e9822791cc943dac8a249f0aee1 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 13:48:32 +0200 Subject: [PATCH 06/49] disentangle tests --- .../test_cross_text_binary_coref.py | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 7a62e43fc..5b119960a 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -96,10 +96,7 @@ def test_construct_negative_documents(taskmodule, positive_documents): "She sleeps a lot.", ] assert len(docs) == 12 - all_scores = [ - [coref_rel.score for coref_rel in doc.binary_coref_relations] - for doc in docs - ] + all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs] assert docs[0].text == TEXTS[1] assert docs[0].text_pair == TEXTS[2] assert docs[0].binary_coref_relations.resolve() == [ @@ -282,15 +279,23 @@ def
test_encode_target(task_encodings_without_target, taskmodule): assert target == 0.0 -@pytest.fixture(scope="module", params=[False, True]) -def batch(taskmodule, positive_documents, documents_with_negatives, request): - if request.param: - original_value = taskmodule.add_negative_relations - taskmodule.add_negative_relations = True - task_encodings = taskmodule.encode(positive_documents, encode_target=True)[:4] - taskmodule.add_negative_relations = original_value - else: - task_encodings = taskmodule.encode(documents_with_negatives[0], encode_target=True) +def test_encode_with_add_negative_relations(taskmodule, positive_documents): + original_value = taskmodule.add_negative_relations + taskmodule.add_negative_relations = False + documents_with_negatives = list(taskmodule._add_negative_relations(positive_documents)) + task_encodings1 = taskmodule.encode(documents_with_negatives, encode_target=True) + taskmodule.add_negative_relations = True + task_encodings2 = taskmodule.encode(positive_documents, encode_target=True) + taskmodule.add_negative_relations = original_value + + for task_encoding1, task_encoding2 in zip(task_encodings1, task_encodings2): + torch.testing.assert_close(task_encoding1.inputs, task_encoding2.inputs) + torch.testing.assert_close(task_encoding1.targets, task_encoding2.targets) + + +@pytest.fixture(scope="module") +def batch(taskmodule, positive_documents, documents_with_negatives): + task_encodings = taskmodule.encode(documents_with_negatives[0], encode_target=True) result = taskmodule.collate(task_encodings) return result From 119ef757d19fe21e92681a4138ad40ed63532f26 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 13:54:39 +0200 Subject: [PATCH 07/49] streamline test --- .../test_cross_text_binary_coref.py | 171 +++++++----------- 1 file changed, 68 insertions(+), 103 deletions(-) diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 5b119960a..1cab5fac7 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -96,114 +96,79 @@ def test_construct_negative_documents(taskmodule, positive_documents): "She sleeps a lot.", ] assert len(docs) == 12 - all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs] - assert docs[0].text == TEXTS[1] - assert docs[0].text_pair == TEXTS[2] - assert docs[0].binary_coref_relations.resolve() == [ - ("coref", (("PERSON", "she"), ("PERSON", "Bob"))), - ("coref", (("PERSON", "she"), ("ANIMAL", "his cat"))), - ("coref", (("COMPANY", "C"), ("PERSON", "Bob"))), - ("coref", (("COMPANY", "C"), ("ANIMAL", "his cat"))), - ] - assert all_scores[0] == [0.0, 0.0, 0.0, 0.0] - - assert docs[1].text == TEXTS[1] - assert docs[1].text_pair == TEXTS[0] - assert docs[1].binary_coref_relations.resolve() == [ - ("coref", (("PERSON", "she"), ("PERSON", "Entity A"))), - ("coref", (("PERSON", "she"), ("COMPANY", "B"))), - ("coref", (("COMPANY", "C"), ("PERSON", "Entity A"))), - ("coref", (("COMPANY", "C"), ("COMPANY", "B"))), - ] - assert all_scores[1] == [1.0, 0.0, 0.0, 0.0] - - assert docs[2].text == TEXTS[1] - assert docs[2].text_pair == TEXTS[3] - assert docs[2].binary_coref_relations.resolve() == [ - ("coref", (("PERSON", "she"), ("ANIMAL", "She"))), - ("coref", (("COMPANY", "C"), ("ANIMAL", "She"))), - ] - assert all_scores[2] == [0.0, 0.0] - - assert docs[3].text == TEXTS[2] - assert docs[3].text_pair == TEXTS[1] - assert docs[3].binary_coref_relations.resolve() == [ - 
("coref", (("PERSON", "Bob"), ("PERSON", "she"))), - ("coref", (("PERSON", "Bob"), ("COMPANY", "C"))), - ("coref", (("ANIMAL", "his cat"), ("PERSON", "she"))), - ("coref", (("ANIMAL", "his cat"), ("COMPANY", "C"))), - ] - assert all_scores[3] == [0.0, 0.0, 0.0, 0.0] - - assert docs[4].text == TEXTS[2] - assert docs[4].text_pair == TEXTS[0] - assert docs[4].binary_coref_relations.resolve() == [ - ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A"))), - ("coref", (("PERSON", "Bob"), ("COMPANY", "B"))), - ("coref", (("ANIMAL", "his cat"), ("PERSON", "Entity A"))), - ("coref", (("ANIMAL", "his cat"), ("COMPANY", "B"))), - ] - assert all_scores[4] == [0.0, 0.0, 0.0, 0.0] + assert all(doc.text in TEXTS for doc in docs) + assert all(doc.text_pair in TEXTS for doc in docs) - assert docs[5].text == TEXTS[2] - assert docs[5].text_pair == TEXTS[3] - assert docs[5].binary_coref_relations.resolve() == [ - ("coref", (("PERSON", "Bob"), ("ANIMAL", "She"))), - ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))), - ] - assert all_scores[5] == [0.0, 1.0] - - assert docs[6].text == TEXTS[0] - assert docs[6].text_pair == TEXTS[1] - assert docs[6].binary_coref_relations.resolve() == [ - ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))), - ("coref", (("PERSON", "Entity A"), ("COMPANY", "C"))), - ("coref", (("COMPANY", "B"), ("PERSON", "she"))), - ("coref", (("COMPANY", "B"), ("COMPANY", "C"))), - ] - assert all_scores[6] == [1.0, 0.0, 0.0, 0.0] - - assert docs[7].text == TEXTS[0] - assert docs[7].text_pair == TEXTS[2] - assert docs[7].binary_coref_relations.resolve() == [ - ("coref", (("PERSON", "Entity A"), ("PERSON", "Bob"))), - ("coref", (("PERSON", "Entity A"), ("ANIMAL", "his cat"))), - ("coref", (("COMPANY", "B"), ("PERSON", "Bob"))), - ("coref", (("COMPANY", "B"), ("ANIMAL", "his cat"))), - ] - assert all_scores[7] == [0.0, 0.0, 0.0, 0.0] - - assert docs[8].text == TEXTS[0] - assert docs[8].text_pair == TEXTS[3] - assert docs[8].binary_coref_relations.resolve() == [ - ("coref", (("PERSON", "Entity A"), ("ANIMAL", "She"))), - ("coref", (("COMPANY", "B"), ("ANIMAL", "She"))), - ] - assert all_scores[8] == [0.0, 0.0] - - assert docs[9].text == TEXTS[3] - assert docs[9].text_pair == TEXTS[1] - assert docs[9].binary_coref_relations.resolve() == [ - ("coref", (("ANIMAL", "She"), ("PERSON", "she"))), - ("coref", (("ANIMAL", "She"), ("COMPANY", "C"))), - ] - assert all_scores[9] == [0.0, 0.0] + all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs] + all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs] - assert docs[10].text == TEXTS[3] - assert docs[10].text_pair == TEXTS[2] - assert docs[10].binary_coref_relations.resolve() == [ - ("coref", (("ANIMAL", "She"), ("PERSON", "Bob"))), - ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))), + all_rels_and_scores = [ + list(zip(scores, rels_resolved)) + for scores, rels_resolved in zip(all_scores, all_rels_resolved) ] - assert all_scores[10] == [0.0, 1.0] - assert docs[11].text == TEXTS[3] - assert docs[11].text_pair == TEXTS[0] - assert docs[11].binary_coref_relations.resolve() == [ - ("coref", (("ANIMAL", "She"), ("PERSON", "Entity A"))), - ("coref", (("ANIMAL", "She"), ("COMPANY", "B"))), + assert all_rels_and_scores == [ + [ + (0.0, ("coref", (("PERSON", "she"), ("PERSON", "Bob")))), + (0.0, ("coref", (("PERSON", "she"), ("ANIMAL", "his cat")))), + (0.0, ("coref", (("COMPANY", "C"), ("PERSON", "Bob")))), + (0.0, ("coref", (("COMPANY", "C"), ("ANIMAL", "his cat")))), + ], + [ + (1.0, ("coref", 
(("PERSON", "she"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("PERSON", "she"), ("COMPANY", "B")))), + (0.0, ("coref", (("COMPANY", "C"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))), + ], + [ + (0.0, ("coref", (("PERSON", "she"), ("ANIMAL", "She")))), + (0.0, ("coref", (("COMPANY", "C"), ("ANIMAL", "She")))), + ], + [ + (0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she")))), + (0.0, ("coref", (("PERSON", "Bob"), ("COMPANY", "C")))), + (0.0, ("coref", (("ANIMAL", "his cat"), ("PERSON", "she")))), + (0.0, ("coref", (("ANIMAL", "his cat"), ("COMPANY", "C")))), + ], + [ + (0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("PERSON", "Bob"), ("COMPANY", "B")))), + (0.0, ("coref", (("ANIMAL", "his cat"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("ANIMAL", "his cat"), ("COMPANY", "B")))), + ], + [ + (0.0, ("coref", (("PERSON", "Bob"), ("ANIMAL", "She")))), + (1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She")))), + ], + [ + (1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))), + (0.0, ("coref", (("PERSON", "Entity A"), ("COMPANY", "C")))), + (0.0, ("coref", (("COMPANY", "B"), ("PERSON", "she")))), + (0.0, ("coref", (("COMPANY", "B"), ("COMPANY", "C")))), + ], + [ + (0.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "Bob")))), + (0.0, ("coref", (("PERSON", "Entity A"), ("ANIMAL", "his cat")))), + (0.0, ("coref", (("COMPANY", "B"), ("PERSON", "Bob")))), + (0.0, ("coref", (("COMPANY", "B"), ("ANIMAL", "his cat")))), + ], + [ + (0.0, ("coref", (("PERSON", "Entity A"), ("ANIMAL", "She")))), + (0.0, ("coref", (("COMPANY", "B"), ("ANIMAL", "She")))), + ], + [ + (0.0, ("coref", (("ANIMAL", "She"), ("PERSON", "she")))), + (0.0, ("coref", (("ANIMAL", "She"), ("COMPANY", "C")))), + ], + [ + (0.0, ("coref", (("ANIMAL", "She"), ("PERSON", "Bob")))), + (1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat")))), + ], + [ + (0.0, ("coref", (("ANIMAL", "She"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("ANIMAL", "She"), ("COMPANY", "B")))), + ], ] - assert all_scores[11] == [0.0, 0.0] @pytest.fixture(scope="module") From 6b90f997d61b0636f6a75ef0098ee5eab5e09675 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 14:08:20 +0200 Subject: [PATCH 08/49] improve test --- .../test_cross_text_binary_coref.py | 162 +++++++++++------- 1 file changed, 99 insertions(+), 63 deletions(-) diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 1cab5fac7..cbb9281b6 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -95,79 +95,115 @@ def test_construct_negative_documents(taskmodule, positive_documents): "Bob loves his cat.", "She sleeps a lot.", ] - assert len(docs) == 12 assert all(doc.text in TEXTS for doc in docs) assert all(doc.text_pair in TEXTS for doc in docs) + all_texts = [(doc.text, doc.text_pair) for doc in docs] all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs] all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs] all_rels_and_scores = [ - list(zip(scores, rels_resolved)) - for scores, rels_resolved in zip(all_scores, all_rels_resolved) + (texts, list(zip(scores, rels_resolved))) + for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved) ] assert all_rels_and_scores == [ - [ - (0.0, ("coref", (("PERSON", "she"), ("PERSON", "Bob")))), - (0.0, ("coref", (("PERSON", "she"), 
("ANIMAL", "his cat")))), - (0.0, ("coref", (("COMPANY", "C"), ("PERSON", "Bob")))), - (0.0, ("coref", (("COMPANY", "C"), ("ANIMAL", "his cat")))), - ], - [ - (1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A")))), - (0.0, ("coref", (("PERSON", "she"), ("COMPANY", "B")))), - (0.0, ("coref", (("COMPANY", "C"), ("PERSON", "Entity A")))), - (0.0, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))), - ], - [ - (0.0, ("coref", (("PERSON", "she"), ("ANIMAL", "She")))), - (0.0, ("coref", (("COMPANY", "C"), ("ANIMAL", "She")))), - ], - [ - (0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she")))), - (0.0, ("coref", (("PERSON", "Bob"), ("COMPANY", "C")))), - (0.0, ("coref", (("ANIMAL", "his cat"), ("PERSON", "she")))), - (0.0, ("coref", (("ANIMAL", "his cat"), ("COMPANY", "C")))), - ], - [ - (0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A")))), - (0.0, ("coref", (("PERSON", "Bob"), ("COMPANY", "B")))), - (0.0, ("coref", (("ANIMAL", "his cat"), ("PERSON", "Entity A")))), - (0.0, ("coref", (("ANIMAL", "his cat"), ("COMPANY", "B")))), - ], - [ - (0.0, ("coref", (("PERSON", "Bob"), ("ANIMAL", "She")))), - (1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She")))), - ], - [ - (1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))), - (0.0, ("coref", (("PERSON", "Entity A"), ("COMPANY", "C")))), - (0.0, ("coref", (("COMPANY", "B"), ("PERSON", "she")))), - (0.0, ("coref", (("COMPANY", "B"), ("COMPANY", "C")))), - ], - [ - (0.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "Bob")))), - (0.0, ("coref", (("PERSON", "Entity A"), ("ANIMAL", "his cat")))), - (0.0, ("coref", (("COMPANY", "B"), ("PERSON", "Bob")))), - (0.0, ("coref", (("COMPANY", "B"), ("ANIMAL", "his cat")))), - ], - [ - (0.0, ("coref", (("PERSON", "Entity A"), ("ANIMAL", "She")))), - (0.0, ("coref", (("COMPANY", "B"), ("ANIMAL", "She")))), - ], - [ - (0.0, ("coref", (("ANIMAL", "She"), ("PERSON", "she")))), - (0.0, ("coref", (("ANIMAL", "She"), ("COMPANY", "C")))), - ], - [ - (0.0, ("coref", (("ANIMAL", "She"), ("PERSON", "Bob")))), - (1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat")))), - ], - [ - (0.0, ("coref", (("ANIMAL", "She"), ("PERSON", "Entity A")))), - (0.0, ("coref", (("ANIMAL", "She"), ("COMPANY", "B")))), - ], + ( + ("And she founded C.", "Bob loves his cat."), + [ + (0.0, ("coref", (("PERSON", "she"), ("PERSON", "Bob")))), + (0.0, ("coref", (("PERSON", "she"), ("ANIMAL", "his cat")))), + (0.0, ("coref", (("COMPANY", "C"), ("PERSON", "Bob")))), + (0.0, ("coref", (("COMPANY", "C"), ("ANIMAL", "his cat")))), + ], + ), + ( + ("And she founded C.", "Entity A works at B."), + [ + (1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("PERSON", "she"), ("COMPANY", "B")))), + (0.0, ("coref", (("COMPANY", "C"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))), + ], + ), + ( + ("And she founded C.", "She sleeps a lot."), + [ + (0.0, ("coref", (("PERSON", "she"), ("ANIMAL", "She")))), + (0.0, ("coref", (("COMPANY", "C"), ("ANIMAL", "She")))), + ], + ), + ( + ("Bob loves his cat.", "And she founded C."), + [ + (0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she")))), + (0.0, ("coref", (("PERSON", "Bob"), ("COMPANY", "C")))), + (0.0, ("coref", (("ANIMAL", "his cat"), ("PERSON", "she")))), + (0.0, ("coref", (("ANIMAL", "his cat"), ("COMPANY", "C")))), + ], + ), + ( + ("Bob loves his cat.", "Entity A works at B."), + [ + (0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("PERSON", "Bob"), ("COMPANY", "B")))), + (0.0, 
("coref", (("ANIMAL", "his cat"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("ANIMAL", "his cat"), ("COMPANY", "B")))), + ], + ), + ( + ("Bob loves his cat.", "She sleeps a lot."), + [ + (0.0, ("coref", (("PERSON", "Bob"), ("ANIMAL", "She")))), + (1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She")))), + ], + ), + ( + ("Entity A works at B.", "And she founded C."), + [ + (1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))), + (0.0, ("coref", (("PERSON", "Entity A"), ("COMPANY", "C")))), + (0.0, ("coref", (("COMPANY", "B"), ("PERSON", "she")))), + (0.0, ("coref", (("COMPANY", "B"), ("COMPANY", "C")))), + ], + ), + ( + ("Entity A works at B.", "Bob loves his cat."), + [ + (0.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "Bob")))), + (0.0, ("coref", (("PERSON", "Entity A"), ("ANIMAL", "his cat")))), + (0.0, ("coref", (("COMPANY", "B"), ("PERSON", "Bob")))), + (0.0, ("coref", (("COMPANY", "B"), ("ANIMAL", "his cat")))), + ], + ), + ( + ("Entity A works at B.", "She sleeps a lot."), + [ + (0.0, ("coref", (("PERSON", "Entity A"), ("ANIMAL", "She")))), + (0.0, ("coref", (("COMPANY", "B"), ("ANIMAL", "She")))), + ], + ), + ( + ("She sleeps a lot.", "And she founded C."), + [ + (0.0, ("coref", (("ANIMAL", "She"), ("PERSON", "she")))), + (0.0, ("coref", (("ANIMAL", "She"), ("COMPANY", "C")))), + ], + ), + ( + ("She sleeps a lot.", "Bob loves his cat."), + [ + (0.0, ("coref", (("ANIMAL", "She"), ("PERSON", "Bob")))), + (1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat")))), + ], + ), + ( + ("She sleeps a lot.", "Entity A works at B."), + [ + (0.0, ("coref", (("ANIMAL", "She"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("ANIMAL", "She"), ("COMPANY", "B")))), + ], + ), ] From 5067e168b267e8da2b7380a642a16e95ef96ca71 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 14:10:23 +0200 Subject: [PATCH 09/49] create negatives from text to itself (but different spans) --- .../taskmodules/cross_text_binary_coref.py | 5 +++-- .../test_cross_text_binary_coref.py | 22 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index 968be53bf..970455ffa 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -101,8 +101,6 @@ def _add_negative_relations(self, positives: Iterable[DocumentType]) -> Iterable new_docs = [] for text in sorted(text2spans): for text_pair in sorted(text2spans): - if text == text_pair: - continue current_positives = positive_tuples.get((text, text_pair), set()) new_doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( text=text, text_pair=text_pair @@ -115,6 +113,9 @@ def _add_negative_relations(self, positives: Iterable[DocumentType]) -> Iterable ) for s in sorted(new_doc.labeled_spans): for s_p in sorted(new_doc.labeled_spans_pair): + # exclude relations to itself + if text == text_pair and s.copy() == s_p.copy(): + continue score = 1.0 if (s.copy(), s_p.copy()) in current_positives else 0.0 new_coref_rel = BinaryCorefRelation(head=s, tail=s_p, score=score) new_doc.binary_coref_relations.append(new_coref_rel) diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index cbb9281b6..635a86eee 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -108,6 +108,13 @@ def 
test_construct_negative_documents(taskmodule, positive_documents): ] assert all_rels_and_scores == [ + ( + ("And she founded C.", "And she founded C."), + [ + (0.0, ("coref", (("PERSON", "she"), ("COMPANY", "C")))), + (0.0, ("coref", (("COMPANY", "C"), ("PERSON", "she")))), + ], + ), ( ("And she founded C.", "Bob loves his cat."), [ @@ -142,6 +149,13 @@ def test_construct_negative_documents(taskmodule, positive_documents): (0.0, ("coref", (("ANIMAL", "his cat"), ("COMPANY", "C")))), ], ), + ( + ("Bob loves his cat.", "Bob loves his cat."), + [ + (0.0, ("coref", (("PERSON", "Bob"), ("ANIMAL", "his cat")))), + (0.0, ("coref", (("ANIMAL", "his cat"), ("PERSON", "Bob")))), + ], + ), ( ("Bob loves his cat.", "Entity A works at B."), [ @@ -176,6 +190,13 @@ def test_construct_negative_documents(taskmodule, positive_documents): (0.0, ("coref", (("COMPANY", "B"), ("ANIMAL", "his cat")))), ], ), + ( + ("Entity A works at B.", "Entity A works at B."), + [ + (0.0, ("coref", (("PERSON", "Entity A"), ("COMPANY", "B")))), + (0.0, ("coref", (("COMPANY", "B"), ("PERSON", "Entity A")))), + ], + ), ( ("Entity A works at B.", "She sleeps a lot."), [ @@ -204,6 +225,7 @@ def test_construct_negative_documents(taskmodule, positive_documents): (0.0, ("coref", (("ANIMAL", "She"), ("COMPANY", "B")))), ], ), + (("She sleeps a lot.", "She sleeps a lot."), []), ] From a93abde0ff92a6935da21fb9b484ace2e4f0cf7f Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 14:13:49 +0200 Subject: [PATCH 10/49] restrict candidates by having same entity type --- .../taskmodules/cross_text_binary_coref.py | 2 + .../test_cross_text_binary_coref.py | 98 +++---------------- 2 files changed, 15 insertions(+), 85 deletions(-) diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index 970455ffa..0b63d7a69 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -116,6 +116,8 @@ def _add_negative_relations(self, positives: Iterable[DocumentType]) -> Iterable # exclude relations to itself if text == text_pair and s.copy() == s_p.copy(): continue + if s.label != s_p.label: + continue score = 1.0 if (s.copy(), s_p.copy()) in current_positives else 0.0 new_coref_rel = BinaryCorefRelation(head=s, tail=s_p, score=score) new_doc.binary_coref_relations.append(new_coref_rel) diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 635a86eee..9a994b773 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -108,123 +108,51 @@ def test_construct_negative_documents(taskmodule, positive_documents): ] assert all_rels_and_scores == [ - ( - ("And she founded C.", "And she founded C."), - [ - (0.0, ("coref", (("PERSON", "she"), ("COMPANY", "C")))), - (0.0, ("coref", (("COMPANY", "C"), ("PERSON", "she")))), - ], - ), + (("And she founded C.", "And she founded C."), []), ( ("And she founded C.", "Bob loves his cat."), - [ - (0.0, ("coref", (("PERSON", "she"), ("PERSON", "Bob")))), - (0.0, ("coref", (("PERSON", "she"), ("ANIMAL", "his cat")))), - (0.0, ("coref", (("COMPANY", "C"), ("PERSON", "Bob")))), - (0.0, ("coref", (("COMPANY", "C"), ("ANIMAL", "his cat")))), - ], + [(0.0, ("coref", (("PERSON", "she"), ("PERSON", "Bob"))))], ), ( ("And she founded C.", "Entity A works at B."), [ (1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A")))), - (0.0, ("coref", (("PERSON", 
"she"), ("COMPANY", "B")))), - (0.0, ("coref", (("COMPANY", "C"), ("PERSON", "Entity A")))), (0.0, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))), ], ), - ( - ("And she founded C.", "She sleeps a lot."), - [ - (0.0, ("coref", (("PERSON", "she"), ("ANIMAL", "She")))), - (0.0, ("coref", (("COMPANY", "C"), ("ANIMAL", "She")))), - ], - ), + (("And she founded C.", "She sleeps a lot."), []), ( ("Bob loves his cat.", "And she founded C."), - [ - (0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she")))), - (0.0, ("coref", (("PERSON", "Bob"), ("COMPANY", "C")))), - (0.0, ("coref", (("ANIMAL", "his cat"), ("PERSON", "she")))), - (0.0, ("coref", (("ANIMAL", "his cat"), ("COMPANY", "C")))), - ], - ), - ( - ("Bob loves his cat.", "Bob loves his cat."), - [ - (0.0, ("coref", (("PERSON", "Bob"), ("ANIMAL", "his cat")))), - (0.0, ("coref", (("ANIMAL", "his cat"), ("PERSON", "Bob")))), - ], + [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she"))))], ), + (("Bob loves his cat.", "Bob loves his cat."), []), ( ("Bob loves his cat.", "Entity A works at B."), - [ - (0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A")))), - (0.0, ("coref", (("PERSON", "Bob"), ("COMPANY", "B")))), - (0.0, ("coref", (("ANIMAL", "his cat"), ("PERSON", "Entity A")))), - (0.0, ("coref", (("ANIMAL", "his cat"), ("COMPANY", "B")))), - ], + [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A"))))], ), ( ("Bob loves his cat.", "She sleeps a lot."), - [ - (0.0, ("coref", (("PERSON", "Bob"), ("ANIMAL", "She")))), - (1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She")))), - ], + [(1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))))], ), ( ("Entity A works at B.", "And she founded C."), [ (1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))), - (0.0, ("coref", (("PERSON", "Entity A"), ("COMPANY", "C")))), - (0.0, ("coref", (("COMPANY", "B"), ("PERSON", "she")))), (0.0, ("coref", (("COMPANY", "B"), ("COMPANY", "C")))), ], ), ( ("Entity A works at B.", "Bob loves his cat."), - [ - (0.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "Bob")))), - (0.0, ("coref", (("PERSON", "Entity A"), ("ANIMAL", "his cat")))), - (0.0, ("coref", (("COMPANY", "B"), ("PERSON", "Bob")))), - (0.0, ("coref", (("COMPANY", "B"), ("ANIMAL", "his cat")))), - ], - ), - ( - ("Entity A works at B.", "Entity A works at B."), - [ - (0.0, ("coref", (("PERSON", "Entity A"), ("COMPANY", "B")))), - (0.0, ("coref", (("COMPANY", "B"), ("PERSON", "Entity A")))), - ], - ), - ( - ("Entity A works at B.", "She sleeps a lot."), - [ - (0.0, ("coref", (("PERSON", "Entity A"), ("ANIMAL", "She")))), - (0.0, ("coref", (("COMPANY", "B"), ("ANIMAL", "She")))), - ], - ), - ( - ("She sleeps a lot.", "And she founded C."), - [ - (0.0, ("coref", (("ANIMAL", "She"), ("PERSON", "she")))), - (0.0, ("coref", (("ANIMAL", "She"), ("COMPANY", "C")))), - ], + [(0.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "Bob"))))], ), + (("Entity A works at B.", "Entity A works at B."), []), + (("Entity A works at B.", "She sleeps a lot."), []), + (("She sleeps a lot.", "And she founded C."), []), ( ("She sleeps a lot.", "Bob loves his cat."), - [ - (0.0, ("coref", (("ANIMAL", "She"), ("PERSON", "Bob")))), - (1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat")))), - ], - ), - ( - ("She sleeps a lot.", "Entity A works at B."), - [ - (0.0, ("coref", (("ANIMAL", "She"), ("PERSON", "Entity A")))), - (0.0, ("coref", (("ANIMAL", "She"), ("COMPANY", "B")))), - ], + [(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))], ), + (("She sleeps a lot.", "Entity A works at 
B."), []), (("She sleeps a lot.", "She sleeps a lot."), []), ] From ddadf5e3242f03e37328b88a7a70915b5c4cfa4c Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 15:50:05 +0200 Subject: [PATCH 11/49] make RelationStatisticsMixin ready for multi-label or binary --- src/pie_modules/taskmodules/common/mixins.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/pie_modules/taskmodules/common/mixins.py b/src/pie_modules/taskmodules/common/mixins.py index 876d2f9cc..4a6f647d1 100644 --- a/src/pie_modules/taskmodules/common/mixins.py +++ b/src/pie_modules/taskmodules/common/mixins.py @@ -185,7 +185,10 @@ def finalize_statistics(self): else: raise ValueError(f"unknown key: {key}") for rel in rels_set: - self.increase_counter(key=(key, rel.label)) + # Set "no_relation" as label when the score is zero. We encode negative relations + # in such a way in the case of multi-label or binary (similarity for coref). + label = rel.label if rel.score > 0 else "no_relation" + self.increase_counter(key=(key, label)) for rel in skipped_other: self.increase_counter(key=("skipped_other", rel.label)) From 4768fe0b8588c659f85dbd11413e05e2891a4ca3 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 15:50:57 +0200 Subject: [PATCH 12/49] use RelationStatisticsMixin --- .../taskmodules/cross_text_binary_coref.py | 13 ++++++++++-- .../test_cross_text_binary_coref.py | 21 +++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index 0b63d7a69..0f2e9d900 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -25,6 +25,7 @@ BinaryCorefRelation, TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, ) +from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction from pie_modules.utils import list_of_dicts2dict_of_lists @@ -65,7 +66,9 @@ def _get_labels(model_output: ModelTargetType) -> torch.Tensor: @TaskModule.register() -class CrossTextBinaryCorefTaskModule(TaskModuleType, ChangesTokenizerVocabSize): +class CrossTextBinaryCorefTaskModule( + RelationStatisticsMixin, TaskModuleType, ChangesTokenizerVocabSize +): DOCUMENT_TYPE = DocumentType def __init__( @@ -126,18 +129,22 @@ def _add_negative_relations(self, positives: Iterable[DocumentType]) -> Iterable return new_docs def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs): + self.reset_statistics() if self.add_negative_relations: if isinstance(documents, DocumentType): documents = [documents] documents = self._add_negative_relations(documents) - return super().encode(documents=documents, **kwargs) + result = super().encode(documents=documents, **kwargs) + self.show_statistics() + return result def encode_input( self, document: DocumentType, is_training: bool = False, ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]: + self.collect_all_relations(kind="available", relations=document.binary_coref_relations) tokenizer_kwargs = dict( padding=False, truncation=True, @@ -158,6 +165,7 @@ def encode_input( logger.warning( f"Could not get token offsets for arguments of coref relation: {coref_rel.resolve()}. Skip it." 
) + self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel) continue task_encodings.append( TaskEncoding( @@ -173,6 +181,7 @@ def encode_input( metadata={"candidate_annotation": coref_rel}, ) ) + self.collect_relation("used", coref_rel) return task_encodings def encode_target( diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 9a994b773..b8961d0dd 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -1,4 +1,5 @@ import json +import logging from typing import Any, Dict, Union import pytest @@ -244,6 +245,26 @@ def test_encode_with_add_negative_relations(taskmodule, positive_documents): torch.testing.assert_close(task_encoding1.targets, task_encoding2.targets) +def test_encode_with_collect_statistics(taskmodule, positive_documents, caplog): + caplog.clear() + with caplog.at_level(logging.INFO): + original_values = taskmodule.add_negative_relations, taskmodule.collect_statistics + taskmodule.add_negative_relations = True + taskmodule.collect_statistics = True + taskmodule.encode(positive_documents, encode_target=True) + taskmodule.add_negative_relations, taskmodule.collect_statistics = original_values + + assert len(caplog.messages) == 1 + assert ( + caplog.messages[0] == "statistics:\n" + "| | coref | no_relation | all_relations |\n" + "|:----------|--------:|--------------:|----------------:|\n" + "| available | 4 | 6 | 4 |\n" + "| used | 4 | 6 | 4 |\n" + "| used % | 100 | 100 | 100 |" + ) + + @pytest.fixture(scope="module") def batch(taskmodule, positive_documents, documents_with_negatives): task_encodings = taskmodule.encode(documents_with_negatives[0], encode_target=True) From 5d0c7fa484603b74d362a900223c3feaf9bcfe36 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 16:03:45 +0200 Subject: [PATCH 13/49] rename model to SpanSimilarityModel; add similarity_threshold parameter; set num_indices when "mention_pooling" is used --- src/pie_modules/models/__init__.py | 1 + .../{simple_similarity.py => span_similarity.py} | 13 ++++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) rename src/pie_modules/models/{simple_similarity.py => span_similarity.py} (94%) diff --git a/src/pie_modules/models/__init__.py b/src/pie_modules/models/__init__.py index f64038f80..e454a38e1 100644 --- a/src/pie_modules/models/__init__.py +++ b/src/pie_modules/models/__init__.py @@ -3,6 +3,7 @@ from .simple_generative import SimpleGenerativeModel from .simple_sequence_classification import SimpleSequenceClassificationModel from .simple_token_classification import SimpleTokenClassificationModel +from .span_similarity import SpanSimilarityModel from .span_tuple_classification import SpanTupleClassificationModel from .token_classification_with_seq2seq_encoder_and_crf import ( TokenClassificationModelWithSeq2SeqEncoderAndCrf, diff --git a/src/pie_modules/models/simple_similarity.py b/src/pie_modules/models/span_similarity.py similarity index 94% rename from src/pie_modules/models/simple_similarity.py rename to src/pie_modules/models/span_similarity.py index 6d3766f2c..ee436c1c0 100644 --- a/src/pie_modules/models/simple_similarity.py +++ b/src/pie_modules/models/span_similarity.py @@ -32,7 +32,7 @@ @PyTorchIEModel.register() -class SimpleSimilarityModel( +class SpanSimilarityModel( ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType], RequiresModelNameOrPath, ): @@ -63,11 +63,13 @@ def __init__( self, model_name_or_path: 
str, tokenizer_vocab_size: Optional[int] = None, + similarity_threshold: float = 0.5, classifier_dropout: Optional[float] = None, learning_rate: float = 1e-5, task_learning_rate: Optional[float] = None, warmup_proportion: float = 0.1, - # TODO: use "mention_pooling" per default? + # TODO: use "mention_pooling" per default? But this requires + # to also set num_indices=1 in the pooler_config pooler: Optional[Union[Dict[str, Any], str]] = None, freeze_base_model: bool = False, hidden_dim: Optional[int] = None, @@ -77,6 +79,7 @@ def __init__( self.save_hyperparameters() + self.similarity_threshold = similarity_threshold self.learning_rate = learning_rate self.task_learning_rate = task_learning_rate self.warmup_proportion = warmup_proportion @@ -108,6 +111,10 @@ def __init__( if isinstance(pooler, str): pooler = {"type": pooler} + if pooler is not None: + if pooler["type"] == "mention_pooling": + # we have only one index (span) per input to pool + pooler["num_indices"] = 1 self.pooler_config = pooler or {} self.pooler, pooler_output_dim = get_pooler_and_output_size( config=self.pooler_config, @@ -170,7 +177,7 @@ def forward( return SequenceClassifierOutput(**result) def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: - labels = (outputs.logits > 0.5).to(torch.long) + labels = (outputs.logits >= self.similarity_threshold).to(torch.long) return {"labels": labels, "probabilities": outputs.logits} From cc2e8283c3e958294f816c0477521e31cec3185b Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 18:39:08 +0200 Subject: [PATCH 14/49] remove SpanSimilarityModel in favor of new SequencePairSimilarityModelWithPooler --- src/pie_modules/models/__init__.py | 6 +- .../sequence_classification_with_pooler.py | 63 +++++ src/pie_modules/models/span_similarity.py | 217 ------------------ .../taskmodules/cross_text_binary_coref.py | 8 +- .../test_cross_text_binary_coref.py | 12 +- 5 files changed, 77 insertions(+), 229 deletions(-) delete mode 100644 src/pie_modules/models/span_similarity.py diff --git a/src/pie_modules/models/__init__.py b/src/pie_modules/models/__init__.py index e454a38e1..df8f4a035 100644 --- a/src/pie_modules/models/__init__.py +++ b/src/pie_modules/models/__init__.py @@ -1,9 +1,11 @@ -from .sequence_classification_with_pooler import SequenceClassificationModelWithPooler +from .sequence_classification_with_pooler import ( + SequenceClassificationModelWithPooler, + SequencePairSimilarityModelWithPooler, +) from .simple_extractive_question_answering import SimpleExtractiveQuestionAnsweringModel from .simple_generative import SimpleGenerativeModel from .simple_sequence_classification import SimpleSequenceClassificationModel from .simple_token_classification import SimpleTokenClassificationModel -from .span_similarity import SpanSimilarityModel from .span_tuple_classification import SpanTupleClassificationModel from .token_classification_with_seq2seq_encoder_and_crf import ( TokenClassificationModelWithSeq2SeqEncoderAndCrf, diff --git a/src/pie_modules/models/sequence_classification_with_pooler.py b/src/pie_modules/models/sequence_classification_with_pooler.py index 98a6965da..be1cdbb1f 100644 --- a/src/pie_modules/models/sequence_classification_with_pooler.py +++ b/src/pie_modules/models/sequence_classification_with_pooler.py @@ -266,3 +266,66 @@ def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: probabilities = torch.sigmoid(outputs.logits) labels = (probabilities > self.multi_label_threshold).to(torch.long) return {"labels": labels, 
"probabilities": probabilities} + + +@PyTorchIEModel.register() +class SequencePairSimilarityModelWithPooler( + SequenceClassificationModelWithPoolerBase, +): + """TODO. + + Args: + label_threshold: The threshold for the multi-label classifier, i.e. the probability + above which a class is predicted. + **kwargs + """ + + def __init__(self, label_threshold: float = 0.5, **kwargs): + super().__init__(**kwargs) + self.multi_label_threshold = label_threshold + + def setup_classifier(self, pooler_output_dim: int) -> Callable: + return torch.nn.functional.cosine_similarity + + def setup_loss_fct(self): + return nn.BCELoss() + + def forward( + self, + inputs: InputType, + targets: Optional[TargetType] = None, + return_hidden_states: bool = False, + ) -> OutputType: + sanitized_inputs = separate_arguments_by_prefix( + # Note that the order of the prefixes is important because one is a prefix of the other, + # so we need to start with the longer! + arguments=inputs, + prefixes=["pooler_pair_", "pooler_"], + ) + + pooled_output = self.get_pooled_output( + model_inputs=sanitized_inputs["remaining"]["encoding"], + pooler_inputs=sanitized_inputs["pooler_"], + ) + pooled_output_pair = self.get_pooled_output( + model_inputs=sanitized_inputs["remaining"]["encoding_pair"], + pooler_inputs=sanitized_inputs["pooler_pair_"], + ) + + logits = self.classifier(pooled_output, pooled_output_pair) + + result = {"logits": logits} + if targets is not None: + labels = targets["labels"] + loss = self.loss_fct(logits, labels) + result["loss"] = loss + if return_hidden_states: + raise NotImplementedError("return_hidden_states is not yet implemented") + + return SequenceClassifierOutput(**result) + + def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: + # probabilities = torch.sigmoid(outputs.logits) + probabilities = outputs.logits + labels = (probabilities > self.multi_label_threshold).to(torch.long) + return {"labels": labels, "probabilities": probabilities} diff --git a/src/pie_modules/models/span_similarity.py b/src/pie_modules/models/span_similarity.py deleted file mode 100644 index ee436c1c0..000000000 --- a/src/pie_modules/models/span_similarity.py +++ /dev/null @@ -1,217 +0,0 @@ -import logging -from typing import Any, Dict, Iterator, MutableMapping, Optional, Tuple, Union - -import torch -from pytorch_ie.core import PyTorchIEModel -from pytorch_ie.models.interface import RequiresModelNameOrPath -from torch import FloatTensor, LongTensor, nn -from torch.nn import BCELoss, Parameter -from torch.optim import AdamW -from transformers import AutoConfig, AutoModel, get_linear_schedule_with_warmup -from transformers.modeling_outputs import SequenceClassifierOutput -from typing_extensions import TypeAlias - -from .common import ModelWithBoilerplate -from .components.pooler import get_pooler_and_output_size - -# model inputs / outputs / targets -InputType: TypeAlias = MutableMapping[str, LongTensor] -OutputType: TypeAlias = SequenceClassifierOutput -TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]] -# step inputs (batch) / outputs (loss) -StepInputType: TypeAlias = Tuple[InputType, TargetType] -StepOutputType: TypeAlias = FloatTensor - - -HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE = { - "albert": "classifier_dropout_prob", - "distilbert": "seq_classif_dropout", -} - -logger = logging.getLogger(__name__) - - -@PyTorchIEModel.register() -class SpanSimilarityModel( - ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType], - RequiresModelNameOrPath, -): - 
"""TODO. - - Args: - model_name_or_path: The name or path of the HuggingFace model to use. - tokenizer_vocab_size: The size of the tokenizer vocabulary. If provided, the model's - tokenizer embeddings are resized to this size. - classifier_dropout: The dropout probability for the classifier. If not provided, the - dropout probability is taken from the Huggingface model config. - learning_rate: The learning rate for the optimizer. - task_learning_rate: The learning rate for the task-specific parameters. If None, the - learning rate for all parameters is set to `learning_rate`. - warmup_proportion: The proportion of steps to warm up the learning rate. - multi_label: If True, the model is trained as a multi-label classifier. - multi_label_threshold: The threshold for the multi-label classifier, i.e. the probability - above which a class is predicted. - pooler: The pooler configuration. If None, CLS token pooling is used. - freeze_base_model: If True, the base model parameters are frozen. - base_model_prefix: The prefix of the base model parameters when using a task_learning_rate - or freeze_base_model. If None, the base_model_prefix of the model is used. - **kwargs: Additional keyword arguments passed to the parent class, - see :class:`ModelWithBoilerplate`. - """ - - def __init__( - self, - model_name_or_path: str, - tokenizer_vocab_size: Optional[int] = None, - similarity_threshold: float = 0.5, - classifier_dropout: Optional[float] = None, - learning_rate: float = 1e-5, - task_learning_rate: Optional[float] = None, - warmup_proportion: float = 0.1, - # TODO: use "mention_pooling" per default? But this requires - # to also set num_indices=1 in the pooler_config - pooler: Optional[Union[Dict[str, Any], str]] = None, - freeze_base_model: bool = False, - hidden_dim: Optional[int] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.save_hyperparameters() - - self.similarity_threshold = similarity_threshold - self.learning_rate = learning_rate - self.task_learning_rate = task_learning_rate - self.warmup_proportion = warmup_proportion - self.freeze_base_model = freeze_base_model - - config = AutoConfig.from_pretrained(model_name_or_path) - if self.is_from_pretrained: - self.model = AutoModel.from_config(config=config) - else: - self.model = AutoModel.from_pretrained(model_name_or_path, config=config) - - if tokenizer_vocab_size is not None: - self.model.resize_token_embeddings(tokenizer_vocab_size) - - if self.freeze_base_model: - for param in self.model.parameters(): - param.requires_grad = False - - if classifier_dropout is None: - # Get the classifier dropout value from the Huggingface model config. - # This is a bit of a mess since some Configs use different variable names or change the semantics - # of the dropout (e.g. DistilBert has one dropout prob for QA and one for Seq classification, and a - # general one for embeddings, encoder and pooler). 
- classifier_dropout_attr = HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE.get( - config.model_type, "classifier_dropout" - ) - classifier_dropout = getattr(config, classifier_dropout_attr) or 0.0 - self.dropout = nn.Dropout(classifier_dropout) - - if isinstance(pooler, str): - pooler = {"type": pooler} - if pooler is not None: - if pooler["type"] == "mention_pooling": - # we have only one index (span) per input to pool - pooler["num_indices"] = 1 - self.pooler_config = pooler or {} - self.pooler, pooler_output_dim = get_pooler_and_output_size( - config=self.pooler_config, - input_dim=config.hidden_size, - ) - if hidden_dim is not None: - self.classifier = nn.Linear(pooler_output_dim, hidden_dim) - else: - self.classifier = None - - # TODO: is this ok? - self.loss_fct = BCELoss() - - def get_pooled_output(self, model_inputs, pooler_inputs): - output = self.model(**model_inputs) - hidden_state = output.last_hidden_state - pooled_output = self.pooler(hidden_state, **pooler_inputs) - pooled_output = self.dropout(pooled_output) - if self.classifier is not None: - return self.classifier(pooled_output) - return pooled_output - - def forward( - self, - inputs: InputType, - targets: Optional[TargetType] = None, - return_hidden_states: bool = False, - ) -> OutputType: - model_inputs = None - model_inputs_pair = None - pooler_inputs = {} - pooler_inputs_pair = {} - for k, v in inputs.items(): - if k.startswith("pooler_") and k.endswith("_pair"): - k_target = k[len("pooler_") : -len("_pair")] - pooler_inputs_pair[k_target] = v - elif k.startswith("pooler_"): - k_target = k[len("pooler_") :] - pooler_inputs[k_target] = v - elif k == "encoding": - model_inputs = v - elif k == "encoding_pair": - model_inputs_pair = v - else: - raise ValueError(f"unexpected model input: {k}") - - pooled_output = self.get_pooled_output(model_inputs, pooler_inputs) - pooled_output_pair = self.get_pooled_output(model_inputs_pair, pooler_inputs_pair) - - logits = torch.nn.functional.cosine_similarity(pooled_output, pooled_output_pair) - - result = {"logits": logits} - if targets is not None: - labels = targets["labels"] - loss = self.loss_fct(logits, labels) - result["loss"] = loss - if return_hidden_states: - raise NotImplementedError("return_hidden_states is not yet implemented") - - return SequenceClassifierOutput(**result) - - def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: - labels = (outputs.logits >= self.similarity_threshold).to(torch.long) - - return {"labels": labels, "probabilities": outputs.logits} - - def base_model_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: - if prefix: - prefix = f"{prefix}." - return self.model.named_parameters(prefix=f"{prefix}model") - - def task_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: - if prefix: - prefix = f"{prefix}." 
- base_model_parameter_names = dict(self.base_model_named_parameters(prefix=prefix)).keys() - for name, param in self.named_parameters(prefix=prefix): - if name not in base_model_parameter_names: - yield name, param - - def configure_optimizers(self): - if self.task_learning_rate is not None: - base_model_params = (param for name, param in self.base_model_named_parameters()) - task_params = (param for name, param in self.task_named_parameters()) - optimizer = AdamW( - [ - {"params": base_model_params, "lr": self.learning_rate}, - {"params": task_params, "lr": self.task_learning_rate}, - ] - ) - else: - optimizer = AdamW(self.parameters(), lr=self.learning_rate) - - if self.warmup_proportion > 0.0: - stepping_batches = self.trainer.estimated_stepping_batches - scheduler = get_linear_schedule_with_warmup( - optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches - ) - return [optimizer], [{"scheduler": scheduler, "interval": "step"}] - else: - return optimizer diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index 0f2e9d900..c311255ea 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -175,8 +175,8 @@ def encode_input( "encoding_pair": encoding_pair, "pooler_start_indices": start, "pooler_end_indices": end, - "pooler_start_indices_pair": start_pair, - "pooler_end_indices_pair": end_pair, + "pooler_pair_start_indices": start_pair, + "pooler_pair_end_indices": end_pair, }, metadata={"candidate_annotation": coref_rel}, ) @@ -201,13 +201,13 @@ def collate( ) inputs = { - k: self.tokenizer.pad(v, return_tensors="pt") + k: self.tokenizer.pad(v, return_tensors="pt").data if k in ["encoding", "encoding_pair"] else torch.tensor(v) for k, v in inputs_dict.items() } for k, v in inputs.items(): - if k.startswith("pooler_start_indices") or k.startswith("pooler_end_indices"): + if k.startswith("pooler_") and k.endswith("_indices"): inputs[k] = v.unsqueeze(-1) if not task_encodings[0].has_targets: diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index b8961d0dd..29830b4ac 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -218,8 +218,8 @@ def test_encode_input(task_encodings_without_target, taskmodule): toks[start:end] for toks, start, end in zip( tokens_pair, - inputs_dict["pooler_start_indices_pair"], - inputs_dict["pooler_end_indices_pair"], + inputs_dict["pooler_pair_start_indices"], + inputs_dict["pooler_pair_end_indices"], ) ] assert span_tokens == [["she"], ["she"], ["C"], ["C"]] @@ -279,10 +279,10 @@ def test_collate(batch, taskmodule): assert set(inputs) == { "pooler_end_indices", "encoding_pair", - "pooler_end_indices_pair", + "pooler_pair_end_indices", "pooler_start_indices", "encoding", - "pooler_start_indices_pair", + "pooler_pair_start_indices", } torch.testing.assert_close( inputs["encoding"]["input_ids"], @@ -325,10 +325,10 @@ def test_collate(batch, taskmodule): torch.testing.assert_close(inputs["pooler_start_indices"], torch.tensor([[2], [2], [4], [4]])) torch.testing.assert_close(inputs["pooler_end_indices"], torch.tensor([[3], [3], [5], [5]])) torch.testing.assert_close( - inputs["pooler_start_indices_pair"], torch.tensor([[1], [3], [1], [3]]) + inputs["pooler_pair_start_indices"], torch.tensor([[1], [3], [1], [3]]) ) torch.testing.assert_close( - 
inputs["pooler_end_indices_pair"], torch.tensor([[2], [5], [2], [5]]) + inputs["pooler_pair_end_indices"], torch.tensor([[2], [5], [2], [5]]) ) torch.testing.assert_close(targets, {"labels": torch.tensor([0.0, 0.0, 0.0, 0.0])}) From 64fa61eb63e54d194d56129dece2e9588853f2a6 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 18:57:22 +0200 Subject: [PATCH 15/49] add tests for SequencePairSimilarityModelWithPooler --- ...uence_pair_similarity_model_with_pooler.py | 334 ++++++++++++++++++ 1 file changed, 334 insertions(+) create mode 100644 tests/models/test_sequence_pair_similarity_model_with_pooler.py diff --git a/tests/models/test_sequence_pair_similarity_model_with_pooler.py b/tests/models/test_sequence_pair_similarity_model_with_pooler.py new file mode 100644 index 000000000..6ad20e126 --- /dev/null +++ b/tests/models/test_sequence_pair_similarity_model_with_pooler.py @@ -0,0 +1,334 @@ +from typing import Dict + +import pytest +import torch +from pytorch_lightning import Trainer +from torch import LongTensor, tensor +from torch.optim.lr_scheduler import LambdaLR +from transformers.modeling_outputs import SequenceClassifierOutput + +from pie_modules.models import SequencePairSimilarityModelWithPooler +from pie_modules.models.sequence_classification_with_pooler import OutputType +from tests.models import trunc_number + +POOLER = {"type": "mention_pooling", "num_indices": 1} + + +@pytest.fixture +def inputs() -> Dict[str, LongTensor]: + result_dict = { + "encoding": { + "input_ids": tensor( + [ + [101, 1262, 1131, 1771, 140, 119, 102], + [101, 1262, 1131, 1771, 140, 119, 102], + [101, 1262, 1131, 1771, 140, 119, 102], + [101, 1262, 1131, 1771, 140, 119, 102], + ] + ), + "token_type_ids": tensor( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + "attention_mask": tensor( + [ + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + ] + ), + }, + "encoding_pair": { + "input_ids": tensor( + [ + [101, 3162, 7871, 1117, 5855, 119, 102], + [101, 3162, 7871, 1117, 5855, 119, 102], + [101, 3162, 7871, 1117, 5855, 119, 102], + [101, 3162, 7871, 1117, 5855, 119, 102], + ] + ), + "token_type_ids": tensor( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + "attention_mask": tensor( + [ + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + ] + ), + }, + "pooler_start_indices": tensor([[2], [2], [4], [4]]), + "pooler_end_indices": tensor([[3], [3], [5], [5]]), + "pooler_pair_start_indices": tensor([[1], [3], [1], [3]]), + "pooler_pair_end_indices": tensor([[2], [5], [2], [5]]), + } + + return result_dict + + +@pytest.fixture +def targets() -> Dict[str, LongTensor]: + return {"labels": tensor([0.0, 0.0, 0.0, 0.0])} + + +@pytest.fixture +def model() -> SequencePairSimilarityModelWithPooler: + torch.manual_seed(42) + result = SequencePairSimilarityModelWithPooler( + model_name_or_path="prajjwal1/bert-tiny", + pooler=POOLER, + ) + return result + + +def test_model(model): + assert model is not None + named_parameters = dict(model.named_parameters()) + parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()} + parameter_means_expected = { + "model.embeddings.word_embeddings.weight": 0.0031152, + "model.embeddings.position_embeddings.weight": 5.5e-05, + "model.embeddings.token_type_embeddings.weight": -0.0015419, + 
"model.embeddings.LayerNorm.weight": 1.312345, + "model.embeddings.LayerNorm.bias": -0.0294608, + "model.encoder.layer.0.attention.self.query.weight": -0.0003949, + "model.encoder.layer.0.attention.self.query.bias": 0.0185744, + "model.encoder.layer.0.attention.self.key.weight": 0.0003863, + "model.encoder.layer.0.attention.self.key.bias": 0.0020557, + "model.encoder.layer.0.attention.self.value.weight": 4.22e-05, + "model.encoder.layer.0.attention.self.value.bias": 0.0065417, + "model.encoder.layer.0.attention.output.dense.weight": 3.01e-05, + "model.encoder.layer.0.attention.output.dense.bias": 0.0007209, + "model.encoder.layer.0.attention.output.LayerNorm.weight": 1.199831, + "model.encoder.layer.0.attention.output.LayerNorm.bias": 0.0608714, + "model.encoder.layer.0.intermediate.dense.weight": -0.0011731, + "model.encoder.layer.0.intermediate.dense.bias": -0.1219958, + "model.encoder.layer.0.output.dense.weight": -0.0002212, + "model.encoder.layer.0.output.dense.bias": -0.0013031, + "model.encoder.layer.0.output.LayerNorm.weight": 1.2419648, + "model.encoder.layer.0.output.LayerNorm.bias": 0.005295, + "model.encoder.layer.1.attention.self.query.weight": -0.0007321, + "model.encoder.layer.1.attention.self.query.bias": -0.0358397, + "model.encoder.layer.1.attention.self.key.weight": 0.0001333, + "model.encoder.layer.1.attention.self.key.bias": 0.0045062, + "model.encoder.layer.1.attention.self.value.weight": 0.0001012, + "model.encoder.layer.1.attention.self.value.bias": -0.0007094, + "model.encoder.layer.1.attention.output.dense.weight": -2.43e-05, + "model.encoder.layer.1.attention.output.dense.bias": 0.0041446, + "model.encoder.layer.1.attention.output.LayerNorm.weight": 1.0377343, + "model.encoder.layer.1.attention.output.LayerNorm.bias": 0.0443237, + "model.encoder.layer.1.intermediate.dense.weight": -0.001344, + "model.encoder.layer.1.intermediate.dense.bias": -0.1247257, + "model.encoder.layer.1.output.dense.weight": -5.32e-05, + "model.encoder.layer.1.output.dense.bias": 0.000677, + "model.encoder.layer.1.output.LayerNorm.weight": 1.017162, + "model.encoder.layer.1.output.LayerNorm.bias": -0.0474442, + "model.pooler.dense.weight": 0.0001295, + "model.pooler.dense.bias": -0.0052078, + "pooler.missing_embeddings": 0.0812017, + } + assert parameter_means == parameter_means_expected + + +def test_model_pickleable(model): + import pickle + + pickle.dumps(model) + + +@pytest.fixture +def model_output(model, inputs) -> OutputType: + # set seed to make sure the output is deterministic + torch.manual_seed(42) + return model(inputs) + + +def test_forward_logits(model_output, inputs): + assert isinstance(model_output, SequenceClassifierOutput) + + logits = model_output.logits + + torch.testing.assert_close( + logits, + torch.tensor( + [0.5338148474693298, 0.5866107940673828, 0.5076886415481567, 0.5946245789527893] + ), + ) + + +def test_decode(model, model_output, inputs): + decoded = model.decode(inputs=inputs, outputs=model_output) + assert isinstance(decoded, dict) + assert set(decoded) == {"labels", "probabilities"} + labels = decoded["labels"] + torch.testing.assert_close( + labels, + torch.tensor([1, 1, 1, 1]), + ) + probabilities = decoded["probabilities"] + torch.testing.assert_close( + probabilities, + torch.tensor( + [0.5338148474693298, 0.5866107940673828, 0.5076886415481567, 0.5946245789527893] + ), + ) + + +@pytest.fixture +def batch(inputs, targets): + return inputs, targets + + +def test_training_step(batch, model): + # set the seed to make sure the loss is deterministic + 
torch.manual_seed(42) + loss = model.training_step(batch, batch_idx=0) + assert loss is not None + torch.testing.assert_close(loss, torch.tensor(0.8145309686660767)) + + +def test_validation_step(batch, model): + # set the seed to make sure the loss is deterministic + torch.manual_seed(42) + loss = model.validation_step(batch, batch_idx=0) + assert loss is not None + torch.testing.assert_close(loss, torch.tensor(0.8145309686660767)) + + +def test_test_step(batch, model): + # set the seed to make sure the loss is deterministic + torch.manual_seed(42) + loss = model.test_step(batch, batch_idx=0) + assert loss is not None + torch.testing.assert_close(loss, torch.tensor(0.8145309686660767)) + + +def test_base_model_named_parameters(model): + base_model_named_parameters = dict(model.base_model_named_parameters()) + assert set(base_model_named_parameters) == { + "model.pooler.dense.bias", + "model.encoder.layer.0.intermediate.dense.weight", + "model.encoder.layer.0.intermediate.dense.bias", + "model.encoder.layer.1.attention.output.dense.weight", + "model.encoder.layer.1.attention.output.LayerNorm.weight", + "model.encoder.layer.1.attention.self.query.weight", + "model.encoder.layer.1.output.dense.weight", + "model.encoder.layer.0.output.dense.bias", + "model.encoder.layer.1.intermediate.dense.bias", + "model.encoder.layer.1.attention.self.value.bias", + "model.encoder.layer.0.attention.output.dense.weight", + "model.encoder.layer.0.attention.self.query.bias", + "model.encoder.layer.0.attention.self.value.bias", + "model.encoder.layer.1.output.dense.bias", + "model.encoder.layer.1.attention.self.query.bias", + "model.encoder.layer.1.attention.output.LayerNorm.bias", + "model.encoder.layer.0.attention.self.query.weight", + "model.encoder.layer.0.attention.output.LayerNorm.bias", + "model.encoder.layer.0.attention.self.key.bias", + "model.encoder.layer.1.intermediate.dense.weight", + "model.encoder.layer.1.output.LayerNorm.bias", + "model.encoder.layer.1.output.LayerNorm.weight", + "model.encoder.layer.0.attention.self.key.weight", + "model.encoder.layer.1.attention.output.dense.bias", + "model.encoder.layer.0.attention.output.dense.bias", + "model.embeddings.LayerNorm.bias", + "model.encoder.layer.0.attention.self.value.weight", + "model.encoder.layer.0.attention.output.LayerNorm.weight", + "model.embeddings.token_type_embeddings.weight", + "model.encoder.layer.0.output.LayerNorm.weight", + "model.embeddings.position_embeddings.weight", + "model.encoder.layer.1.attention.self.key.bias", + "model.embeddings.LayerNorm.weight", + "model.encoder.layer.0.output.LayerNorm.bias", + "model.encoder.layer.1.attention.self.key.weight", + "model.pooler.dense.weight", + "model.encoder.layer.0.output.dense.weight", + "model.embeddings.word_embeddings.weight", + "model.encoder.layer.1.attention.self.value.weight", + } + + +def test_task_named_parameters(model): + task_named_parameters = dict(model.task_named_parameters()) + assert set(task_named_parameters) == { + "pooler.missing_embeddings", + } + + +def test_configure_optimizers_with_warmup(): + model = SequencePairSimilarityModelWithPooler( + model_name_or_path="prajjwal1/bert-tiny", + ) + model.trainer = Trainer(max_epochs=10) + optimizers_and_schedulers = model.configure_optimizers() + assert len(optimizers_and_schedulers) == 2 + optimizers, schedulers = optimizers_and_schedulers + assert len(optimizers) == 1 + assert len(schedulers) == 1 + optimizer = optimizers[0] + assert optimizer is not None + assert isinstance(optimizer, torch.optim.AdamW) + assert 
optimizer.defaults["lr"] == 1e-05 + assert optimizer.defaults["weight_decay"] == 0.01 + assert optimizer.defaults["eps"] == 1e-08 + + scheduler = schedulers[0] + assert isinstance(scheduler, dict) + assert set(scheduler) == {"scheduler", "interval"} + assert isinstance(scheduler["scheduler"], LambdaLR) + + +def test_configure_optimizers_with_task_learning_rate(monkeypatch): + model = SequencePairSimilarityModelWithPooler( + model_name_or_path="prajjwal1/bert-tiny", + learning_rate=1e-5, + task_learning_rate=1e-3, + # disable warmup to make sure the scheduler is not added which would set the learning rate + # to 0 + warmup_proportion=0.0, + ) + optimizer = model.configure_optimizers() + assert optimizer is not None + assert isinstance(optimizer, torch.optim.AdamW) + assert len(optimizer.param_groups) == 2 + # base model parameters + param_group = optimizer.param_groups[0] + assert len(param_group["params"]) == 39 + assert param_group["lr"] == 1e-5 + # classifier head parameters - there is no head + param_group = optimizer.param_groups[1] + assert len(param_group["params"]) == 0 + assert param_group["lr"] == 1e-3 + # ensure that all parameters are covered + assert set(optimizer.param_groups[0]["params"] + optimizer.param_groups[1]["params"]) == set( + model.parameters() + ) + + +def test_freeze_base_model(monkeypatch, inputs, targets): + model = SequencePairSimilarityModelWithPooler( + model_name_or_path="prajjwal1/bert-tiny", + freeze_base_model=True, + # disable warmup to make sure the scheduler is not added which would set the learning rate + # to 0 + warmup_proportion=0.0, + ) + base_model_params = [param for name, param in model.base_model_named_parameters()] + task_params = [param for name, param in model.task_named_parameters()] + assert len(base_model_params) + len(task_params) == len(list(model.parameters())) + for param in base_model_params: + assert not param.requires_grad + for param in task_params: + assert param.requires_grad From 41af7cef58d9334a75d80c7e948d921745c00f5b Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 19:02:13 +0200 Subject: [PATCH 16/49] use mention pooling per default --- .../models/sequence_classification_with_pooler.py | 7 +++++-- .../test_sequence_pair_similarity_model_with_pooler.py | 7 ++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/pie_modules/models/sequence_classification_with_pooler.py b/src/pie_modules/models/sequence_classification_with_pooler.py index be1cdbb1f..49552eb9a 100644 --- a/src/pie_modules/models/sequence_classification_with_pooler.py +++ b/src/pie_modules/models/sequence_classification_with_pooler.py @@ -280,8 +280,11 @@ class SequencePairSimilarityModelWithPooler( **kwargs """ - def __init__(self, label_threshold: float = 0.5, **kwargs): - super().__init__(**kwargs) + def __init__(self, label_threshold: float = 0.5, pooler: Optional[Union[Dict[str, Any], str]] = None, **kwargs): + if pooler is None: + # use mention pooling per default + pooler = {"type": "mention_pooling", "num_indices": 1} + super().__init__(pooler=pooler, **kwargs) self.multi_label_threshold = label_threshold def setup_classifier(self, pooler_output_dim: int) -> Callable: diff --git a/tests/models/test_sequence_pair_similarity_model_with_pooler.py b/tests/models/test_sequence_pair_similarity_model_with_pooler.py index 6ad20e126..2c8faefeb 100644 --- a/tests/models/test_sequence_pair_similarity_model_with_pooler.py +++ b/tests/models/test_sequence_pair_similarity_model_with_pooler.py @@ -11,8 +11,6 @@ from 
pie_modules.models.sequence_classification_with_pooler import OutputType from tests.models import trunc_number -POOLER = {"type": "mention_pooling", "num_indices": 1} - @pytest.fixture def inputs() -> Dict[str, LongTensor]: @@ -88,7 +86,6 @@ def model() -> SequencePairSimilarityModelWithPooler: torch.manual_seed(42) result = SequencePairSimilarityModelWithPooler( model_name_or_path="prajjwal1/bert-tiny", - pooler=POOLER, ) return result @@ -307,9 +304,9 @@ def test_configure_optimizers_with_task_learning_rate(monkeypatch): param_group = optimizer.param_groups[0] assert len(param_group["params"]) == 39 assert param_group["lr"] == 1e-5 - # classifier head parameters - there is no head + # classifier head parameters - there is just the default embedding (which is not used) param_group = optimizer.param_groups[1] - assert len(param_group["params"]) == 0 + assert len(param_group["params"]) == 1 assert param_group["lr"] == 1e-3 # ensure that all parameters are covered assert set(optimizer.param_groups[0]["params"] + optimizer.param_groups[1]["params"]) == set( From ad896175430c86092965c01108d238d7b862b4c5 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 19:33:22 +0200 Subject: [PATCH 17/49] make pre-commit happy --- .../models/sequence_classification_with_pooler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/pie_modules/models/sequence_classification_with_pooler.py b/src/pie_modules/models/sequence_classification_with_pooler.py index 49552eb9a..61d2bc26c 100644 --- a/src/pie_modules/models/sequence_classification_with_pooler.py +++ b/src/pie_modules/models/sequence_classification_with_pooler.py @@ -280,7 +280,12 @@ class SequencePairSimilarityModelWithPooler( **kwargs """ - def __init__(self, label_threshold: float = 0.5, pooler: Optional[Union[Dict[str, Any], str]] = None, **kwargs): + def __init__( + self, + label_threshold: float = 0.5, + pooler: Optional[Union[Dict[str, Any], str]] = None, + **kwargs, + ): if pooler is None: # use mention pooling per default pooler = {"type": "mention_pooling", "num_indices": 1} From 2a974ea2ab75997cdc93d479c33a91d6e616601c Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 19:35:10 +0200 Subject: [PATCH 18/49] implement unbatch_output() and create_annotations_from_output() --- .../taskmodules/cross_text_binary_coref.py | 17 +++++-- .../test_cross_text_binary_coref.py | 48 ++++++++++++++++++- 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index c311255ea..7a39ae016 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -5,6 +5,7 @@ Dict, Iterable, Iterator, + List, Optional, Sequence, Tuple, @@ -43,7 +44,8 @@ class TaskOutputType(TypedDict, total=False): - scores: Sequence[str] + score: float + is_valid: bool ModelInputType: TypeAlias = Dict[str, torch.Tensor] @@ -224,11 +226,20 @@ def configure_model_metric(self, stage: str) -> Metric: ) def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]: - raise NotImplementedError() + label_ids = model_output["labels"].detach().cpu().tolist() + probabilities = model_output["probabilities"].detach().cpu().tolist() + result: List[TaskOutputType] = [ + {"is_valid": label_id != 0, "score": prob} + for label_id, prob in zip(label_ids, probabilities) + ] + return result def create_annotations_from_output( self, task_encoding: 
TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType], task_output: TaskOutputType, ) -> Iterator[Tuple[str, Annotation]]: - raise NotImplementedError() + if task_output["is_valid"]: + score = task_output["score"] + new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score) + yield "binary_coref_relations", new_coref_rel diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 29830b4ac..40bc5c889 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -266,8 +266,12 @@ def test_encode_with_collect_statistics(taskmodule, positive_documents, caplog): @pytest.fixture(scope="module") -def batch(taskmodule, positive_documents, documents_with_negatives): - task_encodings = taskmodule.encode(documents_with_negatives[0], encode_target=True) +def task_encodings(taskmodule, documents_with_negatives): + return taskmodule.encode(documents_with_negatives[0], encode_target=True) + + +@pytest.fixture(scope="module") +def batch(taskmodule, task_encodings): result = taskmodule.collate(task_encodings) return result @@ -334,6 +338,46 @@ def test_collate(batch, taskmodule): torch.testing.assert_close(targets, {"labels": torch.tensor([0.0, 0.0, 0.0, 0.0])}) +@pytest.fixture(scope="module") +def unbatched_output(taskmodule): + model_output = { + "labels": torch.tensor([1, 1, 1, 1]), + "probabilities": torch.tensor( + [0.5338148474693298, 0.5866107940673828, 0.5076886415481567, 0.5946245789527893] + ), + } + return taskmodule.unbatch_output(model_output=model_output) + + +def test_unbatch_output(unbatched_output, taskmodule): + assert len(unbatched_output) == 4 + assert unbatched_output == [ + {"is_valid": True, "score": 0.5338148474693298}, + {"is_valid": True, "score": 0.5866107940673828}, + {"is_valid": True, "score": 0.5076886415481567}, + {"is_valid": True, "score": 0.5946245789527893}, + ] + + +def test_create_annotation_from_output(taskmodule, task_encodings, unbatched_output): + all_new_annotations = [] + for task_encoding, task_output in zip(task_encodings, unbatched_output): + for new_annotation in taskmodule.create_annotations_from_output( + task_encoding=task_encoding, task_output=task_output + ): + all_new_annotations.append(new_annotation) + assert all(layer_name == "binary_coref_relations" for layer_name, ann in all_new_annotations) + resolve_annotations_with_scores = [ + (round(ann.score, 4), ann.resolve()) for layer_name, ann in all_new_annotations + ] + assert resolve_annotations_with_scores == [ + (0.5338, ("coref", (("PERSON", "she"), ("PERSON", "Bob")))), + (0.5866, ("coref", (("PERSON", "she"), ("ANIMAL", "his cat")))), + (0.5077, ("coref", (("COMPANY", "C"), ("PERSON", "Bob")))), + (0.5946, ("coref", (("COMPANY", "C"), ("ANIMAL", "his cat")))), + ] + + def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]: if isinstance(metric_or_collection, Metric): return flatten_dict(metric_or_collection.metric_state) From 6a2e8787bb9bfa6b7716c958b2404571bb5a8cf8 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 20:39:09 +0200 Subject: [PATCH 19/49] implement long text handling --- .../taskmodules/cross_text_binary_coref.py | 87 ++++++++++++++++--- .../test_cross_text_binary_coref.py | 33 +++++++ 2 files changed, 106 insertions(+), 14 deletions(-) diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index 
7a39ae016..c9ef53b61 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -1,3 +1,4 @@ +import copy import logging from collections import defaultdict from typing import ( @@ -15,11 +16,13 @@ import torch from pytorch_ie import Annotation +from pytorch_ie.annotations import Span from pytorch_ie.core import TaskEncoding, TaskModule from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize +from pytorch_ie.utils.window import get_window_around_slice from torchmetrics import Metric, MetricCollection from torchmetrics.classification import BinaryAUROC -from transformers import AutoTokenizer +from transformers import AutoTokenizer, BatchEncoding from typing_extensions import TypeAlias from pie_modules.document.types import ( @@ -63,6 +66,16 @@ class TaskOutputType(TypedDict, total=False): ] +class SpanNotAlignedWithTokenException(Exception): + def __init__(self, span): + self.span = span + + +class SpanDoesNotFitIntoAvailableWindow(Exception): + def __init__(self, span): + self.span = span + + def _get_labels(model_output: ModelTargetType) -> torch.Tensor: return model_output["labels"] @@ -77,6 +90,7 @@ def __init__( self, tokenizer_name_or_path: str, add_negative_relations: bool = False, + max_window: Optional[int] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -85,6 +99,8 @@ def __init__( self.add_negative_relations = add_negative_relations self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + self.max_window = max_window if max_window is not None else self.tokenizer.model_max_length + self.available_window = self.max_window - self.tokenizer.num_special_tokens_to_add() def _add_negative_relations(self, positives: Iterable[DocumentType]) -> Iterable[DocumentType]: positive_tuples = defaultdict(set) @@ -141,6 +157,35 @@ def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwarg self.show_statistics() return result + def truncate_encoding_around_span( + self, encoding: BatchEncoding, char_span: Span + ) -> Tuple[Dict[str, List[int]], Span]: + input_ids = copy.deepcopy(encoding["input_ids"]) + + token_start = encoding.char_to_token(char_span.start) + token_end_before = encoding.char_to_token(char_span.end - 1) + if token_start is None or token_end_before is None: + raise SpanNotAlignedWithTokenException(span=char_span) + token_end = token_end_before + 1 + + # truncate input_ids and shift token_start and token_end + if len(input_ids) > self.available_window: + window_slice = get_window_around_slice( + slice=[token_start, token_end], + max_window_size=self.available_window, + available_input_length=len(input_ids), + ) + if window_slice is None: + raise SpanDoesNotFitIntoAvailableWindow(span=(token_start, token_end)) + window_start, window_end = window_slice + input_ids = input_ids[window_start:window_end] + token_start -= window_start + token_end -= window_start + + truncated_encoding = self.tokenizer.prepare_for_model(ids=input_ids) + + return truncated_encoding, Span(start=token_start, end=token_end) + def encode_input( self, document: DocumentType, @@ -152,33 +197,47 @@ def encode_input( truncation=True, max_length=self.tokenizer.model_max_length, return_offsets_mapping=False, - add_special_tokens=True, + add_special_tokens=False, ) encoding = self.tokenizer(text=document.text, **tokenizer_kwargs) encoding_pair = self.tokenizer(text=document.text_pair, **tokenizer_kwargs) task_encodings = [] for coref_rel in document.binary_coref_relations: - start = 
encoding.char_to_token(coref_rel.head.start) - end = encoding.char_to_token(coref_rel.head.end - 1) + 1 - start_pair = encoding_pair.char_to_token(coref_rel.tail.start) - end_pair = encoding_pair.char_to_token(coref_rel.tail.end - 1) + 1 - if any(offset is None for offset in [start, end, start_pair, end_pair]): + try: + current_encoding, token_span = self.truncate_encoding_around_span( + encoding=encoding, char_span=coref_rel.head + ) + current_encoding_pair, token_span_pair = self.truncate_encoding_around_span( + encoding=encoding_pair, char_span=coref_rel.tail + ) + except SpanNotAlignedWithTokenException as e: logger.warning( - f"Could not get token offsets for arguments of coref relation: {coref_rel.resolve()}. Skip it." + f"Could not get token offsets for argument ({e.span}) of coref relation: " + f"{coref_rel.resolve()}. Skip it." ) self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel) continue + except SpanDoesNotFitIntoAvailableWindow as e: + logger.warning( + f"Argument span [{e.span}] does not fit into available token window " + f"({self.available_window}). Skip it." + ) + self.collect_relation( + kind="skipped_span_does_not_fit_into_window", relation=coref_rel + ) + continue + task_encodings.append( TaskEncoding( document=document, inputs={ - "encoding": encoding, - "encoding_pair": encoding_pair, - "pooler_start_indices": start, - "pooler_end_indices": end, - "pooler_pair_start_indices": start_pair, - "pooler_pair_end_indices": end_pair, + "encoding": current_encoding, + "encoding_pair": current_encoding_pair, + "pooler_start_indices": token_span.start, + "pooler_end_indices": token_span.end, + "pooler_pair_start_indices": token_span_pair.start, + "pooler_pair_end_indices": token_span_pair.end, }, metadata={"candidate_annotation": coref_rel}, ) diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 40bc5c889..252c88b7c 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -265,6 +265,39 @@ def test_encode_with_collect_statistics(taskmodule, positive_documents, caplog): ) +def test_encode_with_windowing(documents_with_negatives, caplog): + tokenizer_name_or_path = "bert-base-cased" + taskmodule = CrossTextBinaryCorefTaskModule( + tokenizer_name_or_path=tokenizer_name_or_path, + max_window=4, + collect_statistics=True, + ) + assert not taskmodule.is_from_pretrained + taskmodule.prepare(documents_with_negatives) + + assert len(documents_with_negatives) == 12 + caplog.clear() + with caplog.at_level(logging.INFO): + task_encodings = taskmodule.encode(documents_with_negatives) + assert len(caplog.messages) > 0 + assert ( + caplog.messages[-1] == "statistics:\n" + "| | coref | no_relation | all_relations |\n" + "|:--------------------------------------|--------:|--------------:|----------------:|\n" + "| available | 4 | 32 | 4 |\n" + "| skipped_span_does_not_fit_into_window | 2 | 8 | 2 |\n" + "| used | 2 | 24 | 2 |\n" + "| used % | 50 | 75 | 50 |" + ) + + assert len(task_encodings) == 26 + for task_encoding in task_encodings: + for k, v in task_encoding.inputs["encoding"].items(): + assert len(v) <= taskmodule.max_window + for k, v in task_encoding.inputs["encoding_pair"].items(): + assert len(v) <= taskmodule.max_window + + @pytest.fixture(scope="module") def task_encodings(taskmodule, documents_with_negatives): return taskmodule.encode(documents_with_negatives[0], encode_target=True) From 617be19f31cb7e152c695b128ceb281d4851e2b2 
Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Thu, 12 Sep 2024 20:49:30 +0200
Subject: [PATCH 20/49] set default label_threshold for SequencePairSimilarityModelWithPooler to 0.9

---
 .../models/sequence_classification_with_pooler.py | 14 ++++++++------
 ...t_sequence_pair_similarity_model_with_pooler.py | 1 +
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/pie_modules/models/sequence_classification_with_pooler.py b/src/pie_modules/models/sequence_classification_with_pooler.py
index 61d2bc26c..b72c7fd2d 100644
--- a/src/pie_modules/models/sequence_classification_with_pooler.py
+++ b/src/pie_modules/models/sequence_classification_with_pooler.py
@@ -275,27 +275,29 @@ class SequencePairSimilarityModelWithPooler(
     """TODO.

     Args:
-        label_threshold: The threshold for the multi-label classifier, i.e. the probability
-            above which a class is predicted.
+        label_threshold: The score threshold above which two spans are considered similar.
+        pooler:
         **kwargs
     """

     def __init__(
         self,
-        label_threshold: float = 0.5,
+        label_threshold: float = 0.9,
         pooler: Optional[Union[Dict[str, Any], str]] = None,
         **kwargs,
     ):
         if pooler is None:
-            # use mention pooling per default
+            # use (max) mention pooling by default
             pooler = {"type": "mention_pooling", "num_indices": 1}
         super().__init__(pooler=pooler, **kwargs)
         self.multi_label_threshold = label_threshold

-    def setup_classifier(self, pooler_output_dim: int) -> Callable:
+    def setup_classifier(
+        self, pooler_output_dim: int
+    ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
         return torch.nn.functional.cosine_similarity

-    def setup_loss_fct(self):
+    def setup_loss_fct(self) -> Callable:
         return nn.BCELoss()

     def forward(
diff --git a/tests/models/test_sequence_pair_similarity_model_with_pooler.py b/tests/models/test_sequence_pair_similarity_model_with_pooler.py
index 2c8faefeb..6bd48b6bb 100644
--- a/tests/models/test_sequence_pair_similarity_model_with_pooler.py
+++ b/tests/models/test_sequence_pair_similarity_model_with_pooler.py
@@ -86,6 +86,7 @@ def model() -> SequencePairSimilarityModelWithPooler:
     torch.manual_seed(42)
     result = SequencePairSimilarityModelWithPooler(
         model_name_or_path="prajjwal1/bert-tiny",
+        label_threshold=0.5,
     )
     return result

From 309346a049456b8cde146d29da941f3201abb60a Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Thu, 12 Sep 2024 21:02:07 +0200
Subject: [PATCH 21/49] fix missed index shift because of added special tokens

---
 src/pie_modules/taskmodules/cross_text_binary_coref.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py
index c9ef53b61..8235389af 100644
--- a/src/pie_modules/taskmodules/cross_text_binary_coref.py
+++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py
@@ -101,6 +101,11 @@ def __init__(
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
         self.max_window = max_window if max_window is not None else self.tokenizer.model_max_length
         self.available_window = self.max_window - self.tokenizer.num_special_tokens_to_add()
+        self.num_special_tokens_before = len(self._get_special_tokens_before())
+
+    def _get_special_tokens_before(self) -> List[int]:
+        dummy_ids = self.tokenizer.build_inputs_with_special_tokens(token_ids_0=[-1])
+        return dummy_ids[: dummy_ids.index(-1)]

     def _add_negative_relations(self, positives: Iterable[DocumentType]) -> Iterable[DocumentType]:
         positive_tuples = defaultdict(set)
@@ -183,6 +188,9 @@ def 
truncate_encoding_around_span( token_end -= window_start truncated_encoding = self.tokenizer.prepare_for_model(ids=input_ids) + # shift indices because we added special tokens to the input_ids + token_start += self.num_special_tokens_before + token_end += self.num_special_tokens_before return truncated_encoding, Span(start=token_start, end=token_end) From 8681f2494b297909d78a3f124422977c1e5b0558 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 21:05:28 +0200 Subject: [PATCH 22/49] minor fixes --- src/pie_modules/taskmodules/cross_text_binary_coref.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index 8235389af..9f196317b 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -18,7 +18,6 @@ from pytorch_ie import Annotation from pytorch_ie.annotations import Span from pytorch_ie.core import TaskEncoding, TaskModule -from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize from pytorch_ie.utils.window import get_window_around_slice from torchmetrics import Metric, MetricCollection from torchmetrics.classification import BinaryAUROC @@ -81,9 +80,7 @@ def _get_labels(model_output: ModelTargetType) -> torch.Tensor: @TaskModule.register() -class CrossTextBinaryCorefTaskModule( - RelationStatisticsMixin, TaskModuleType, ChangesTokenizerVocabSize -): +class CrossTextBinaryCorefTaskModule(RelationStatisticsMixin, TaskModuleType): DOCUMENT_TYPE = DocumentType def __init__( @@ -97,13 +94,12 @@ def __init__( self.save_hyperparameters() self.add_negative_relations = add_negative_relations - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) self.max_window = max_window if max_window is not None else self.tokenizer.model_max_length self.available_window = self.max_window - self.tokenizer.num_special_tokens_to_add() - self.num_special_tokens_before = len(self._get_special_tokens_before()) + self.num_special_tokens_before = len(self._get_special_tokens_before_input()) - def _get_special_tokens_before(self) -> List[int]: + def _get_special_tokens_before_input(self) -> List[int]: dummy_ids = self.tokenizer.build_inputs_with_special_tokens(token_ids_0=[-1]) return dummy_ids[: dummy_ids.index(-1)] From 4376d16f3928068663253288379ac45e041198e2 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 12 Sep 2024 21:15:06 +0200 Subject: [PATCH 23/49] add check for direction of coref relations (should point from text to text_pair) --- .../taskmodules/cross_text_binary_coref.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index 9f196317b..f0d6368d3 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -208,6 +208,18 @@ def encode_input( task_encodings = [] for coref_rel in document.binary_coref_relations: + # TODO: This can miss instances if both texts are the same. We could check that + # coref_rel.head is in document.labeled_spans (same for the tail), but would this + # slow down the encoding? 
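+            # For example (spans and texts taken from the test fixtures): with
+            # text="And she founded C." and text_pair="Bob loves his cat.", a relation
+            # with head="she" (a span over text) and tail="Bob" (a span over text_pair)
+            # passes the check below, whereas a fully reversed relation
+            # (head="Bob", tail="she") raises the ValueError.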
+            if not (
+                coref_rel.head.target == document.text
+                or coref_rel.tail.target == document.text_pair
+            ):
+                raise ValueError(
+                    f"It is expected that coref relations go from (head) spans over 'text' "
+                    f"to (tail) spans over 'text_pair', but this is not the case for this "
+                    f"relation (i.e. it points in the other direction): {coref_rel.resolve()}"
+                )
             try:
                 current_encoding, token_span = self.truncate_encoding_around_span(
                     encoding=encoding, char_span=coref_rel.head

From ebd09b2f7f2cba52524c4017a5e3e73f46184a78 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Thu, 12 Sep 2024 21:24:23 +0200
Subject: [PATCH 24/49] add documentation for SequencePairSimilarityModelWithPooler

---
 .../models/sequence_classification_with_pooler.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/pie_modules/models/sequence_classification_with_pooler.py b/src/pie_modules/models/sequence_classification_with_pooler.py
index b72c7fd2d..be9e39f6c 100644
--- a/src/pie_modules/models/sequence_classification_with_pooler.py
+++ b/src/pie_modules/models/sequence_classification_with_pooler.py
@@ -272,11 +272,15 @@ def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
 class SequencePairSimilarityModelWithPooler(
     SequenceClassificationModelWithPoolerBase,
 ):
-    """TODO.
+    """A span pair similarity model to detect whether two spans occurring in different texts are
+    similar. It uses an encoder to independently calculate contextualized embeddings of both texts,
+    then uses a pooler to get representations of the spans and, finally, calculates the cosine
+    similarity to get the similarity scores.

     Args:
         label_threshold: The score threshold above which two spans are considered similar.
-        pooler:
+        pooler: The pooler identifier or config, see :func:`get_pooler_and_output_size` for details.
+            Defaults to "mention_pooling" (max pooling over the span token embeddings).
         **kwargs
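+
+    Example (a minimal sketch along the lines of the tests in this repo; `documents` is a
+    placeholder for a list of TextPairDocumentWithLabeledSpansAndBinaryCorefRelations)::
+
+        taskmodule = CrossTextBinaryCorefTaskModule(tokenizer_name_or_path="bert-base-cased")
+        taskmodule.prepare(documents)
+        task_encodings = taskmodule.encode(documents, encode_target=True)
+        inputs, targets = taskmodule.collate(task_encodings)
+        model = SequencePairSimilarityModelWithPooler(model_name_or_path="prajjwal1/bert-tiny")
+        scores = model(inputs).logits  # cosine similarities, one per span pair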
     """

     def __init__(

From 98cc08f3212bf032f2d9f035531cc55f6e5ef531 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Thu, 12 Sep 2024 21:27:42 +0200
Subject: [PATCH 25/49] add model and taskmodule to readme

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index d6576146f..dac0027c0 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,7 @@ Available models:

 - [SimpleSequenceClassificationModel](src/pie_modules/models/simple_sequence_classification.py)
 - [SequenceClassificationModelWithPooler](src/pie_modules/models/sequence_classification_with_pooler.py)
+- [SequencePairSimilarityModelWithPooler](src/pie_modules/models/sequence_classification_with_pooler.py)
 - [SimpleTokenClassificationModel](src/pie_modules/models/simple_token_classification.py)
 - [TokenClassificationModelWithSeq2SeqEncoderAndCrf](src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py)
 - [SimpleExtractiveQuestionAnsweringModel](src/pie_modules/models/simple_extractive_question_answering.py)
@@ -25,6 +26,7 @@ Available taskmodules:

 - [RETextClassificationWithIndicesTaskModule](src/pie_modules/taskmodules/re_text_classification_with_indices.py)
+- [CrossTextBinaryCorefTaskModule](src/pie_modules/taskmodules/cross_text_binary_coref.py)
 - [LabeledSpanExtractionByTokenClassificationTaskModule](src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py)
 - [ExtractiveQuestionAnsweringTaskModule](src/pie_modules/taskmodules/extractive_question_answering.py)
 - [TextToTextTaskModule](src/pie_modules/taskmodules/text_to_text.py)

From 9599ee2476003e5727eabf05fefbc9c287e54a71 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Thu, 12 Sep 2024 21:32:37 +0200
Subject: [PATCH 26/49] add short documentation for CrossTextBinaryCorefTaskModule

---
 src/pie_modules/taskmodules/cross_text_binary_coref.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py
index f0d6368d3..a4bbf2e6e 100644
--- a/src/pie_modules/taskmodules/cross_text_binary_coref.py
+++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py
@@ -81,6 +81,10 @@ def _get_labels(model_output: ModelTargetType) -> torch.Tensor:

 @TaskModule.register()
 class CrossTextBinaryCorefTaskModule(RelationStatisticsMixin, TaskModuleType):
+    """This taskmodule processes documents of type
+    TextPairDocumentWithLabeledSpansAndBinaryCorefRelations in preparation for a
+    SequencePairSimilarityModelWithPooler."""
+
     DOCUMENT_TYPE = DocumentType

     def __init__(

From 0cb17087272afdb7357c279a9a863e6683c88317 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Fri, 13 Sep 2024 13:25:14 +0200
Subject: [PATCH 27/49] outsource add_negative_relations() to document.processing.text_pair

---
 .../document/processing/text_pair.py | 48 +++++++++++++++++
 .../taskmodules/cross_text_binary_coref.py | 51 +------------------
 .../test_cross_text_binary_coref.py | 30 +++--------
 3 files changed, 58 insertions(+), 71 deletions(-)
 create mode 100644 src/pie_modules/document/processing/text_pair.py

diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py
new file mode 100644
index 000000000..642293136
--- /dev/null
+++ b/src/pie_modules/document/processing/text_pair.py
@@ -0,0 +1,48 @@
+from collections import defaultdict
+from typing import Iterable
+
+from pie_modules.document.types import (
+    BinaryCorefRelation,
+    TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
+)
+
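+# Construction overview: every text is paired with every text (including itself),
+# all labeled spans are copied into the respective span layers, and a
+# BinaryCorefRelation is added for each pair of equally labeled spans (a span is
+# never paired with itself): score 1.0 if the pair was annotated as coreferent,
+# score 0.0 otherwise, i.e. a negative.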
+ +def add_negative_relations( + documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations], **kwargs +) -> Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]: + positive_tuples = defaultdict(set) + text2spans = defaultdict(set) + for doc in documents: + for labeled_span in doc.labeled_spans: + text2spans[doc.text].add(labeled_span.copy()) + for labeled_span in doc.labeled_spans_pair: + text2spans[doc.text_pair].add(labeled_span.copy()) + + for coref in doc.binary_coref_relations: + positive_tuples[(doc.text, doc.text_pair)].add((coref.head.copy(), coref.tail.copy())) + positive_tuples[(doc.text_pair, doc.text)].add((coref.tail.copy(), coref.head.copy())) + + new_docs = [] + for text in sorted(text2spans): + for text_pair in sorted(text2spans): + current_positives = positive_tuples.get((text, text_pair), set()) + new_doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + text=text, text_pair=text_pair + ) + new_doc.labeled_spans.extend(labeled_span.copy() for labeled_span in text2spans[text]) + new_doc.labeled_spans_pair.extend( + labeled_span.copy() for labeled_span in text2spans[text_pair] + ) + for s in sorted(new_doc.labeled_spans): + for s_p in sorted(new_doc.labeled_spans_pair): + # exclude relations to itself + if text == text_pair and s.copy() == s_p.copy(): + continue + if s.label != s_p.label: + continue + score = 1.0 if (s.copy(), s_p.copy()) in current_positives else 0.0 + new_coref_rel = BinaryCorefRelation(head=s, tail=s_p, score=score) + new_doc.binary_coref_relations.append(new_coref_rel) + new_docs.append(new_doc) + + return new_docs diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index a4bbf2e6e..d737fef12 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -24,6 +24,7 @@ from transformers import AutoTokenizer, BatchEncoding from typing_extensions import TypeAlias +from pie_modules.document.processing.text_pair import add_negative_relations from pie_modules.document.types import ( BinaryCorefRelation, TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, @@ -90,14 +91,12 @@ class CrossTextBinaryCorefTaskModule(RelationStatisticsMixin, TaskModuleType): def __init__( self, tokenizer_name_or_path: str, - add_negative_relations: bool = False, max_window: Optional[int] = None, **kwargs, ) -> None: super().__init__(**kwargs) self.save_hyperparameters() - self.add_negative_relations = add_negative_relations self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) self.max_window = max_window if max_window is not None else self.tokenizer.model_max_length self.available_window = self.max_window - self.tokenizer.num_special_tokens_to_add() @@ -108,56 +107,10 @@ def _get_special_tokens_before_input(self) -> List[int]: return dummy_ids[: dummy_ids.index(-1)] def _add_negative_relations(self, positives: Iterable[DocumentType]) -> Iterable[DocumentType]: - positive_tuples = defaultdict(set) - text2spans = defaultdict(set) - for doc in positives: - for labeled_span in doc.labeled_spans: - text2spans[doc.text].add(labeled_span.copy()) - for labeled_span in doc.labeled_spans_pair: - text2spans[doc.text_pair].add(labeled_span.copy()) - - for coref in doc.binary_coref_relations: - positive_tuples[(doc.text, doc.text_pair)].add( - (coref.head.copy(), coref.tail.copy()) - ) - positive_tuples[(doc.text_pair, doc.text)].add( - (coref.tail.copy(), coref.head.copy()) - ) - - new_docs 
= [] - for text in sorted(text2spans): - for text_pair in sorted(text2spans): - current_positives = positive_tuples.get((text, text_pair), set()) - new_doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( - text=text, text_pair=text_pair - ) - new_doc.labeled_spans.extend( - labeled_span.copy() for labeled_span in text2spans[text] - ) - new_doc.labeled_spans_pair.extend( - labeled_span.copy() for labeled_span in text2spans[text_pair] - ) - for s in sorted(new_doc.labeled_spans): - for s_p in sorted(new_doc.labeled_spans_pair): - # exclude relations to itself - if text == text_pair and s.copy() == s_p.copy(): - continue - if s.label != s_p.label: - continue - score = 1.0 if (s.copy(), s_p.copy()) in current_positives else 0.0 - new_coref_rel = BinaryCorefRelation(head=s, tail=s_p, score=score) - new_doc.binary_coref_relations.append(new_coref_rel) - new_docs.append(new_doc) - - return new_docs + return add_negative_relations(documents=positives) def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs): self.reset_statistics() - if self.add_negative_relations: - if isinstance(documents, DocumentType): - documents = [documents] - documents = self._add_negative_relations(documents) - result = super().encode(documents=documents, **kwargs) self.show_statistics() return result diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 252c88b7c..55e2f1c50 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -7,6 +7,7 @@ from pytorch_ie.annotations import LabeledSpan from torchmetrics import Metric, MetricCollection +from pie_modules.document.processing.text_pair import add_negative_relations from pie_modules.document.types import ( BinaryCorefRelation, TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, @@ -19,7 +20,6 @@ CONFIGS = [ {}, - # {"add_negative_relations": True}, ] CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} @@ -87,9 +87,9 @@ def taskmodule(unprepared_taskmodule, positive_documents): return unprepared_taskmodule -def test_construct_negative_documents(taskmodule, positive_documents): +def test_construct_negative_documents(positive_documents): assert len(positive_documents) == 2 - docs = list(taskmodule._add_negative_relations(positive_documents)) + docs = list(add_negative_relations(positive_documents)) TEXTS = [ "Entity A works at B.", "And she founded C.", @@ -164,7 +164,7 @@ def documents_with_negatives(taskmodule, positive_documents): FIXTURES_ROOT / "taskmodules" / "cross_text_binary_coref" / "documents_with_negatives.json" ) - # result = list(taskmodule._add_negative_relations(positive_documents)) + # result = list(add_negative_relations(positive_documents)) # result_json = [doc.asdict() for doc in result] # with open(file_name, "w") as f: # json.dump(result_json, f, indent=2) @@ -231,28 +231,14 @@ def test_encode_target(task_encodings_without_target, taskmodule): assert target == 0.0 -def test_encode_with_add_negative_relations(taskmodule, positive_documents): - original_value = taskmodule.add_negative_relations - taskmodule.add_negative_relations = False - documents_with_negatives = list(taskmodule._add_negative_relations(positive_documents)) - task_encodings1 = taskmodule.encode(documents_with_negatives, encode_target=True) - taskmodule.add_negative_relations = True - task_encodings2 = taskmodule.encode(positive_documents, encode_target=True) - taskmodule.add_negative_relations = 
original_value - - for task_encoding1, task_encoding2 in zip(task_encodings1, task_encodings2): - torch.testing.assert_close(task_encoding1.inputs, task_encoding2.inputs) - torch.testing.assert_close(task_encoding1.targets, task_encoding2.targets) - - def test_encode_with_collect_statistics(taskmodule, positive_documents, caplog): + documents_with_negatives = add_negative_relations(positive_documents) caplog.clear() with caplog.at_level(logging.INFO): - original_values = taskmodule.add_negative_relations, taskmodule.collect_statistics - taskmodule.add_negative_relations = True + original_values = taskmodule.collect_statistics taskmodule.collect_statistics = True - taskmodule.encode(positive_documents, encode_target=True) - taskmodule.add_negative_relations, taskmodule.collect_statistics = original_values + taskmodule.encode(documents_with_negatives, encode_target=True) + taskmodule.collect_statistics = original_values assert len(caplog.messages) == 1 assert ( From 6b0c032de4826147a53641a891c317b06281b4de Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 13 Sep 2024 13:45:09 +0200 Subject: [PATCH 28/49] update documents_with_negatives.json with current output --- .../documents_with_negatives.json | 682 +++++++++--------- .../test_cross_text_binary_coref.py | 83 +-- 2 files changed, 370 insertions(+), 395 deletions(-) diff --git a/tests/fixtures/taskmodules/cross_text_binary_coref/documents_with_negatives.json b/tests/fixtures/taskmodules/cross_text_binary_coref/documents_with_negatives.json index 0e0b4062f..cbc258ecd 100644 --- a/tests/fixtures/taskmodules/cross_text_binary_coref/documents_with_negatives.json +++ b/tests/fixtures/taskmodules/cross_text_binary_coref/documents_with_negatives.json @@ -1,76 +1,103 @@ [ { - "text_pair": "Bob loves his cat.", + "text_pair": "And she founded C.", "text": "And she founded C.", "id": null, "metadata": null, "labeled_spans": { "annotations": [ + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": -2143209897469179365 + }, { "start": 4, "end": 7, "label": "PERSON", "score": 1.0, - "_id": -5246751469876588720 - }, + "_id": 2545181322977893893 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ { "start": 16, "end": 17, "label": "COMPANY", "score": 1.0, - "_id": 3043206444225553475 + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 } ], "predictions": [] }, - "labeled_spans_pair": { + "binary_coref_relations": { + "annotations": [], + "predictions": [] + } + }, + { + "text_pair": "Bob loves his cat.", + "text": "And she founded C.", + "id": null, + "metadata": null, + "labeled_spans": { "annotations": [ { - "start": 10, + "start": 16, "end": 17, - "label": "ANIMAL", + "label": "COMPANY", "score": 1.0, - "_id": 5373078146820384347 + "_id": -2143209897469179365 }, { - "start": 0, - "end": 3, + "start": 4, + "end": 7, "label": "PERSON", "score": 1.0, - "_id": -3679976720952382748 + "_id": 2545181322977893893 } ], "predictions": [] }, - "binary_coref_relations": { + "labeled_spans_pair": { "annotations": [ { - "head": -5246751469876588720, - "tail": -3679976720952382748, - "label": "coref", - "score": 0.0, - "_id": -1226852003320818417 - }, - { - "head": -5246751469876588720, - "tail": 5373078146820384347, - "label": "coref", - "score": 0.0, - "_id": -2897381892745677680 + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -7091027580690283656 }, { - "head": 3043206444225553475, - "tail": 
-3679976720952382748, - "label": "coref", - "score": 0.0, - "_id": 4747715004687052922 - }, + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ { - "head": 3043206444225553475, - "tail": 5373078146820384347, + "head": 2545181322977893893, + "tail": -7091027580690283656, "label": "coref", "score": 0.0, - "_id": 8355440541443623552 + "_id": -1763877672186772918 } ], "predictions": [] @@ -83,38 +110,38 @@ "metadata": null, "labeled_spans": { "annotations": [ - { - "start": 4, - "end": 7, - "label": "PERSON", - "score": 1.0, - "_id": -5246751469876588720 - }, { "start": 16, "end": 17, "label": "COMPANY", "score": 1.0, - "_id": 3043206444225553475 + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 } ], "predictions": [] }, "labeled_spans_pair": { "annotations": [ - { - "start": 0, - "end": 8, - "label": "PERSON", - "score": 1.0, - "_id": 3233654095506762724 - }, { "start": 18, "end": 19, "label": "COMPANY", "score": 1.0, - "_id": -2183238448703307780 + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 } ], "predictions": [] @@ -122,32 +149,18 @@ "binary_coref_relations": { "annotations": [ { - "head": -5246751469876588720, - "tail": 3233654095506762724, + "head": 2545181322977893893, + "tail": -177396764231138184, "label": "coref", "score": 1.0, - "_id": -4357456139038854264 - }, - { - "head": -5246751469876588720, - "tail": -2183238448703307780, - "label": "coref", - "score": 0.0, - "_id": 466813473723110234 - }, - { - "head": 3043206444225553475, - "tail": 3233654095506762724, - "label": "coref", - "score": 0.0, - "_id": -4272399218893089512 + "_id": 5113198133391321397 }, { - "head": 3043206444225553475, - "tail": -2183238448703307780, + "head": -2143209897469179365, + "tail": 3188240167591245379, "label": "coref", "score": 0.0, - "_id": -5602326156476594098 + "_id": -734219494647036300 } ], "predictions": [] @@ -160,19 +173,19 @@ "metadata": null, "labeled_spans": { "annotations": [ - { - "start": 4, - "end": 7, - "label": "PERSON", - "score": 1.0, - "_id": -5246751469876588720 - }, { "start": 16, "end": 17, "label": "COMPANY", "score": 1.0, - "_id": 3043206444225553475 + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 } ], "predictions": [] @@ -184,28 +197,13 @@ "end": 3, "label": "ANIMAL", "score": 1.0, - "_id": -190677143789164847 + "_id": 2360667792531975882 } ], "predictions": [] }, "binary_coref_relations": { - "annotations": [ - { - "head": -5246751469876588720, - "tail": -190677143789164847, - "label": "coref", - "score": 0.0, - "_id": -8510476152511400278 - }, - { - "head": 3043206444225553475, - "tail": -190677143789164847, - "label": "coref", - "score": 0.0, - "_id": 3867128290998447006 - } - ], + "annotations": [], "predictions": [] } }, @@ -216,38 +214,38 @@ "metadata": null, "labeled_spans": { "annotations": [ - { - "start": 10, - "end": 17, - "label": "ANIMAL", - "score": 1.0, - "_id": 5373078146820384347 - }, { "start": 0, "end": 3, "label": "PERSON", "score": 1.0, - "_id": -3679976720952382748 + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 } ], "predictions": [] }, "labeled_spans_pair": { "annotations": [ - { - "start": 4, - 
"end": 7, - "label": "PERSON", - "score": 1.0, - "_id": -5246751469876588720 - }, { "start": 16, "end": 17, "label": "COMPANY", "score": 1.0, - "_id": 3043206444225553475 + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 } ], "predictions": [] @@ -255,76 +253,103 @@ "binary_coref_relations": { "annotations": [ { - "head": -3679976720952382748, - "tail": -5246751469876588720, - "label": "coref", - "score": 0.0, - "_id": -2767191573101294319 - }, - { - "head": -3679976720952382748, - "tail": 3043206444225553475, + "head": -7091027580690283656, + "tail": 2545181322977893893, "label": "coref", "score": 0.0, - "_id": -4437612686117351921 - }, - { - "head": 5373078146820384347, - "tail": -5246751469876588720, - "label": "coref", - "score": 0.0, - "_id": 5020739476238125539 - }, - { - "head": 5373078146820384347, - "tail": 3043206444225553475, - "label": "coref", - "score": 0.0, - "_id": 3902122025380513665 + "_id": 4323963091729289163 } ], "predictions": [] } }, { - "text_pair": "Entity A works at B.", + "text_pair": "Bob loves his cat.", "text": "Bob loves his cat.", "id": null, "metadata": null, "labeled_spans": { "annotations": [ + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -7091027580690283656 + }, { "start": 10, "end": 17, "label": "ANIMAL", "score": 1.0, - "_id": 5373078146820384347 - }, + "_id": -6613361595321704194 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ { "start": 0, "end": 3, "label": "PERSON", "score": 1.0, - "_id": -3679976720952382748 + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 } ], "predictions": [] }, - "labeled_spans_pair": { + "binary_coref_relations": { + "annotations": [], + "predictions": [] + } + }, + { + "text_pair": "Entity A works at B.", + "text": "Bob loves his cat.", + "id": null, + "metadata": null, + "labeled_spans": { "annotations": [ { "start": 0, - "end": 8, + "end": 3, "label": "PERSON", "score": 1.0, - "_id": 3233654095506762724 + "_id": -7091027580690283656 }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ { "start": 18, "end": 19, "label": "COMPANY", "score": 1.0, - "_id": -2183238448703307780 + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 } ], "predictions": [] @@ -332,32 +357,11 @@ "binary_coref_relations": { "annotations": [ { - "head": -3679976720952382748, - "tail": 3233654095506762724, + "head": -7091027580690283656, + "tail": -177396764231138184, "label": "coref", "score": 0.0, - "_id": -8901055438221583123 - }, - { - "head": -3679976720952382748, - "tail": -2183238448703307780, - "label": "coref", - "score": 0.0, - "_id": -8898764560981633135 - }, - { - "head": 5373078146820384347, - "tail": 3233654095506762724, - "label": "coref", - "score": 0.0, - "_id": -1347933737476127508 - }, - { - "head": 5373078146820384347, - "tail": -2183238448703307780, - "label": "coref", - "score": 0.0, - "_id": 3930515724475035731 + "_id": -4269111567075058761 } ], "predictions": [] @@ -370,19 +374,19 @@ "metadata": null, "labeled_spans": { "annotations": [ - { - "start": 10, - "end": 17, - "label": "ANIMAL", - "score": 1.0, - "_id": 5373078146820384347 - }, { "start": 0, "end": 3, "label": "PERSON", "score": 1.0, - "_id": 
-3679976720952382748 + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 } ], "predictions": [] @@ -394,7 +398,7 @@ "end": 3, "label": "ANIMAL", "score": 1.0, - "_id": -190677143789164847 + "_id": 2360667792531975882 } ], "predictions": [] @@ -402,18 +406,11 @@ "binary_coref_relations": { "annotations": [ { - "head": -3679976720952382748, - "tail": -190677143789164847, - "label": "coref", - "score": 0.0, - "_id": -5159885311314414733 - }, - { - "head": 5373078146820384347, - "tail": -190677143789164847, + "head": -6613361595321704194, + "tail": 2360667792531975882, "label": "coref", "score": 1.0, - "_id": -4858368627143918533 + "_id": 8198921634551745514 } ], "predictions": [] @@ -426,38 +423,38 @@ "metadata": null, "labeled_spans": { "annotations": [ - { - "start": 0, - "end": 8, - "label": "PERSON", - "score": 1.0, - "_id": 3233654095506762724 - }, { "start": 18, "end": 19, "label": "COMPANY", "score": 1.0, - "_id": -2183238448703307780 + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 } ], "predictions": [] }, "labeled_spans_pair": { "annotations": [ - { - "start": 4, - "end": 7, - "label": "PERSON", - "score": 1.0, - "_id": -5246751469876588720 - }, { "start": 16, "end": 17, "label": "COMPANY", "score": 1.0, - "_id": 3043206444225553475 + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 } ], "predictions": [] @@ -465,32 +462,18 @@ "binary_coref_relations": { "annotations": [ { - "head": 3233654095506762724, - "tail": -5246751469876588720, + "head": -177396764231138184, + "tail": 2545181322977893893, "label": "coref", "score": 1.0, - "_id": 2444090963512005184 + "_id": -4710872194864906092 }, { - "head": 3233654095506762724, - "tail": 3043206444225553475, + "head": 3188240167591245379, + "tail": -2143209897469179365, "label": "coref", "score": 0.0, - "_id": -7963340116969175614 - }, - { - "head": -2183238448703307780, - "tail": -5246751469876588720, - "label": "coref", - "score": 0.0, - "_id": -9120191367688252721 - }, - { - "head": -2183238448703307780, - "tail": 3043206444225553475, - "label": "coref", - "score": 0.0, - "_id": 7975222748039939420 + "_id": 2636939255468582059 } ], "predictions": [] @@ -503,38 +486,38 @@ "metadata": null, "labeled_spans": { "annotations": [ - { - "start": 0, - "end": 8, - "label": "PERSON", - "score": 1.0, - "_id": 3233654095506762724 - }, { "start": 18, "end": 19, "label": "COMPANY", "score": 1.0, - "_id": -2183238448703307780 + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 } ], "predictions": [] }, "labeled_spans_pair": { "annotations": [ - { - "start": 10, - "end": 17, - "label": "ANIMAL", - "score": 1.0, - "_id": 5373078146820384347 - }, { "start": 0, "end": 3, "label": "PERSON", "score": 1.0, - "_id": -3679976720952382748 + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 } ], "predictions": [] @@ -542,57 +525,84 @@ "binary_coref_relations": { "annotations": [ { - "head": 3233654095506762724, - "tail": -3679976720952382748, - "label": "coref", - "score": 0.0, - "_id": 1280608060947850168 - }, - { - "head": 3233654095506762724, - "tail": 5373078146820384347, - "label": "coref", - "score": 0.0, - "_id": -3000515518015844819 - }, - { - "head": 
-2183238448703307780, - "tail": -3679976720952382748, - "label": "coref", - "score": 0.0, - "_id": -4464070305304755517 - }, - { - "head": -2183238448703307780, - "tail": 5373078146820384347, + "head": -177396764231138184, + "tail": -7091027580690283656, "label": "coref", "score": 0.0, - "_id": 3298512753939125167 + "_id": -1990964066152094896 } ], "predictions": [] } }, { - "text_pair": "She sleeps a lot.", + "text_pair": "Entity A works at B.", "text": "Entity A works at B.", "id": null, "metadata": null, "labeled_spans": { "annotations": [ + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": 3188240167591245379 + }, { "start": 0, "end": 8, "label": "PERSON", "score": 1.0, - "_id": 3233654095506762724 + "_id": -177396764231138184 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": 3188240167591245379 }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [], + "predictions": [] + } + }, + { + "text_pair": "She sleeps a lot.", + "text": "Entity A works at B.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ { "start": 18, "end": 19, "label": "COMPANY", "score": 1.0, - "_id": -2183238448703307780 + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 } ], "predictions": [] @@ -604,28 +614,13 @@ "end": 3, "label": "ANIMAL", "score": 1.0, - "_id": -190677143789164847 + "_id": 2360667792531975882 } ], "predictions": [] }, "binary_coref_relations": { - "annotations": [ - { - "head": 3233654095506762724, - "tail": -190677143789164847, - "label": "coref", - "score": 0.0, - "_id": -3444435532096461506 - }, - { - "head": -2183238448703307780, - "tail": -190677143789164847, - "label": "coref", - "score": 0.0, - "_id": -3912955313637853940 - } - ], + "annotations": [], "predictions": [] } }, @@ -641,47 +636,32 @@ "end": 3, "label": "ANIMAL", "score": 1.0, - "_id": -190677143789164847 + "_id": 2360667792531975882 } ], "predictions": [] }, "labeled_spans_pair": { "annotations": [ - { - "start": 4, - "end": 7, - "label": "PERSON", - "score": 1.0, - "_id": -5246751469876588720 - }, { "start": 16, "end": 17, "label": "COMPANY", "score": 1.0, - "_id": 3043206444225553475 + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 } ], "predictions": [] }, "binary_coref_relations": { - "annotations": [ - { - "head": -190677143789164847, - "tail": -5246751469876588720, - "label": "coref", - "score": 0.0, - "_id": -6992824161873864749 - }, - { - "head": -190677143789164847, - "tail": 3043206444225553475, - "label": "coref", - "score": 0.0, - "_id": 6180444938490764939 - } - ], + "annotations": [], "predictions": [] } }, @@ -697,26 +677,26 @@ "end": 3, "label": "ANIMAL", "score": 1.0, - "_id": -190677143789164847 + "_id": 2360667792531975882 } ], "predictions": [] }, "labeled_spans_pair": { "annotations": [ - { - "start": 10, - "end": 17, - "label": "ANIMAL", - "score": 1.0, - "_id": 5373078146820384347 - }, { "start": 0, "end": 3, "label": "PERSON", "score": 1.0, - "_id": -3679976720952382748 + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 } ], "predictions": [] @@ -724,18 +704,11 @@ 
"binary_coref_relations": { "annotations": [ { - "head": -190677143789164847, - "tail": -3679976720952382748, - "label": "coref", - "score": 0.0, - "_id": 2061654283494000583 - }, - { - "head": -190677143789164847, - "tail": 5373078146820384347, + "head": 2360667792531975882, + "tail": -6613361595321704194, "label": "coref", "score": 1.0, - "_id": -4650461605955518398 + "_id": -571410837328299027 } ], "predictions": [] @@ -753,48 +726,67 @@ "end": 3, "label": "ANIMAL", "score": 1.0, - "_id": -190677143789164847 + "_id": 2360667792531975882 } ], "predictions": [] }, "labeled_spans_pair": { "annotations": [ - { - "start": 0, - "end": 8, - "label": "PERSON", - "score": 1.0, - "_id": 3233654095506762724 - }, { "start": 18, "end": 19, "label": "COMPANY", "score": 1.0, - "_id": -2183238448703307780 + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 } ], "predictions": [] }, "binary_coref_relations": { + "annotations": [], + "predictions": [] + } + }, + { + "text_pair": "She sleeps a lot.", + "text": "She sleeps a lot.", + "id": null, + "metadata": null, + "labeled_spans": { "annotations": [ { - "head": -190677143789164847, - "tail": 3233654095506762724, - "label": "coref", - "score": 0.0, - "_id": 8092666078797453961 - }, + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": 2360667792531975882 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ { - "head": -190677143789164847, - "tail": -2183238448703307780, - "label": "coref", - "score": 0.0, - "_id": -5075628532960934416 + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": 2360667792531975882 } ], "predictions": [] + }, + "binary_coref_relations": { + "annotations": [], + "predictions": [] } } ] diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 55e2f1c50..939458bf1 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -17,6 +17,7 @@ from tests import FIXTURES_ROOT, _config_to_str TOKENIZER_NAME_OR_PATH = "bert-base-cased" +DOC_IDX_WITH_TASK_ENCODINGS = 2 CONFIGS = [ {}, @@ -181,7 +182,7 @@ def documents_with_negatives(taskmodule, positive_documents): @pytest.fixture(scope="module") def task_encodings_without_target(taskmodule, documents_with_negatives): - task_encodings = taskmodule.encode_input(documents_with_negatives[0]) + task_encodings = taskmodule.encode_input(documents_with_negatives[DOC_IDX_WITH_TASK_ENCODINGS]) return task_encodings @@ -199,14 +200,10 @@ def test_encode_input(task_encodings_without_target, taskmodule): assert tokens == [ ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"], ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"], - ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"], - ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"], ] assert tokens_pair == [ - ["[CLS]", "Bob", "loves", "his", "cat", ".", "[SEP]"], - ["[CLS]", "Bob", "loves", "his", "cat", ".", "[SEP]"], - ["[CLS]", "Bob", "loves", "his", "cat", ".", "[SEP]"], - ["[CLS]", "Bob", "loves", "his", "cat", ".", "[SEP]"], + ["[CLS]", "En", "##ti", "##ty", "A", "works", "at", "B", ".", "[SEP]"], + ["[CLS]", "En", "##ti", "##ty", "A", "works", "at", "B", ".", "[SEP]"], ] span_tokens = [ toks[start:end] @@ -222,13 +219,15 @@ def test_encode_input(task_encodings_without_target, taskmodule): inputs_dict["pooler_pair_end_indices"], ) ] - assert span_tokens == [["she"], ["she"], 
["C"], ["C"]] - assert span_tokens_pair == [["Bob"], ["his", "cat"], ["Bob"], ["his", "cat"]] + assert span_tokens == [["she"], ["C"]] + assert span_tokens_pair == [["En", "##ti", "##ty", "A"], ["B"]] def test_encode_target(task_encodings_without_target, taskmodule): - target = taskmodule.encode_target(task_encodings_without_target[0]) - assert target == 0.0 + targets = [ + taskmodule.encode_target(task_encoding) for task_encoding in task_encodings_without_target + ] + assert targets == [1.0, 0.0] def test_encode_with_collect_statistics(taskmodule, positive_documents, caplog): @@ -261,7 +260,7 @@ def test_encode_with_windowing(documents_with_negatives, caplog): assert not taskmodule.is_from_pretrained taskmodule.prepare(documents_with_negatives) - assert len(documents_with_negatives) == 12 + assert len(documents_with_negatives) == 16 caplog.clear() with caplog.at_level(logging.INFO): task_encodings = taskmodule.encode(documents_with_negatives) @@ -270,13 +269,13 @@ def test_encode_with_windowing(documents_with_negatives, caplog): caplog.messages[-1] == "statistics:\n" "| | coref | no_relation | all_relations |\n" "|:--------------------------------------|--------:|--------------:|----------------:|\n" - "| available | 4 | 32 | 4 |\n" - "| skipped_span_does_not_fit_into_window | 2 | 8 | 2 |\n" - "| used | 2 | 24 | 2 |\n" - "| used % | 50 | 75 | 50 |" + "| available | 4 | 6 | 4 |\n" + "| skipped_span_does_not_fit_into_window | 2 | 2 | 2 |\n" + "| used | 2 | 4 | 2 |\n" + "| used % | 50 | 67 | 50 |" ) - assert len(task_encodings) == 26 + assert len(task_encodings) == 6 for task_encoding in task_encodings: for k, v in task_encoding.inputs["encoding"].items(): assert len(v) <= taskmodule.max_window @@ -286,7 +285,9 @@ def test_encode_with_windowing(documents_with_negatives, caplog): @pytest.fixture(scope="module") def task_encodings(taskmodule, documents_with_negatives): - return taskmodule.encode(documents_with_negatives[0], encode_target=True) + return taskmodule.encode( + documents_with_negatives[DOC_IDX_WITH_TASK_ENCODINGS], encode_target=True + ) @pytest.fixture(scope="module") @@ -310,12 +311,7 @@ def test_collate(batch, taskmodule): torch.testing.assert_close( inputs["encoding"]["input_ids"], torch.tensor( - [ - [101, 1262, 1131, 1771, 140, 119, 102], - [101, 1262, 1131, 1771, 140, 119, 102], - [101, 1262, 1131, 1771, 140, 119, 102], - [101, 1262, 1131, 1771, 140, 119, 102], - ] + [[101, 1262, 1131, 1771, 140, 119, 102], [101, 1262, 1131, 1771, 140, 119, 102]] ), ) torch.testing.assert_close( @@ -329,10 +325,8 @@ def test_collate(batch, taskmodule): inputs["encoding_pair"]["input_ids"], torch.tensor( [ - [101, 3162, 7871, 1117, 5855, 119, 102], - [101, 3162, 7871, 1117, 5855, 119, 102], - [101, 3162, 7871, 1117, 5855, 119, 102], - [101, 3162, 7871, 1117, 5855, 119, 102], + [101, 13832, 3121, 2340, 138, 1759, 1120, 139, 119, 102], + [101, 13832, 3121, 2340, 138, 1759, 1120, 139, 119, 102], ] ), ) @@ -345,36 +339,28 @@ def test_collate(batch, taskmodule): torch.ones_like(inputs["encoding_pair"]["input_ids"]), ) - torch.testing.assert_close(inputs["pooler_start_indices"], torch.tensor([[2], [2], [4], [4]])) - torch.testing.assert_close(inputs["pooler_end_indices"], torch.tensor([[3], [3], [5], [5]])) - torch.testing.assert_close( - inputs["pooler_pair_start_indices"], torch.tensor([[1], [3], [1], [3]]) - ) - torch.testing.assert_close( - inputs["pooler_pair_end_indices"], torch.tensor([[2], [5], [2], [5]]) - ) + torch.testing.assert_close(inputs["pooler_start_indices"], torch.tensor([[2], 
[4]])) + torch.testing.assert_close(inputs["pooler_end_indices"], torch.tensor([[3], [5]])) + torch.testing.assert_close(inputs["pooler_pair_start_indices"], torch.tensor([[1], [7]])) + torch.testing.assert_close(inputs["pooler_pair_end_indices"], torch.tensor([[5], [8]])) - torch.testing.assert_close(targets, {"labels": torch.tensor([0.0, 0.0, 0.0, 0.0])}) + torch.testing.assert_close(targets, {"labels": torch.tensor([1.0, 0.0])}) @pytest.fixture(scope="module") def unbatched_output(taskmodule): model_output = { - "labels": torch.tensor([1, 1, 1, 1]), - "probabilities": torch.tensor( - [0.5338148474693298, 0.5866107940673828, 0.5076886415481567, 0.5946245789527893] - ), + "labels": torch.tensor([0, 1]), + "probabilities": torch.tensor([0.5338148474693298, 0.9866107940673828]), } return taskmodule.unbatch_output(model_output=model_output) def test_unbatch_output(unbatched_output, taskmodule): - assert len(unbatched_output) == 4 + assert len(unbatched_output) == 2 assert unbatched_output == [ - {"is_valid": True, "score": 0.5338148474693298}, - {"is_valid": True, "score": 0.5866107940673828}, - {"is_valid": True, "score": 0.5076886415481567}, - {"is_valid": True, "score": 0.5946245789527893}, + {"is_valid": False, "score": 0.5338148474693298}, + {"is_valid": True, "score": 0.9866107702255249}, ] @@ -390,10 +376,7 @@ def test_create_annotation_from_output(taskmodule, task_encodings, unbatched_out (round(ann.score, 4), ann.resolve()) for layer_name, ann in all_new_annotations ] assert resolve_annotations_with_scores == [ - (0.5338, ("coref", (("PERSON", "she"), ("PERSON", "Bob")))), - (0.5866, ("coref", (("PERSON", "she"), ("ANIMAL", "his cat")))), - (0.5077, ("coref", (("COMPANY", "C"), ("PERSON", "Bob")))), - (0.5946, ("coref", (("COMPANY", "C"), ("ANIMAL", "his cat")))), + (0.9866, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))), ] From 1901779e23415a4645fce3ce94672738768b08c5 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 13 Sep 2024 16:44:33 +0200 Subject: [PATCH 29/49] rename add_negative_relations() to add_negative_coref_relations() and remove from CrossTextBinaryCorefTaskModule --- src/pie_modules/document/processing/text_pair.py | 2 +- src/pie_modules/taskmodules/cross_text_binary_coref.py | 6 ------ tests/taskmodules/test_cross_text_binary_coref.py | 6 +++--- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py index 642293136..0275cdf72 100644 --- a/src/pie_modules/document/processing/text_pair.py +++ b/src/pie_modules/document/processing/text_pair.py @@ -7,7 +7,7 @@ ) -def add_negative_relations( +def add_negative_coref_relations( documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations], **kwargs ) -> Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]: positive_tuples = defaultdict(set) diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index d737fef12..5a8bfba4f 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -1,6 +1,5 @@ import copy import logging -from collections import defaultdict from typing import ( Any, Dict, @@ -24,9 +23,7 @@ from transformers import AutoTokenizer, BatchEncoding from typing_extensions import TypeAlias -from pie_modules.document.processing.text_pair import add_negative_relations from pie_modules.document.types import ( - BinaryCorefRelation, 
TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, ) from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin @@ -106,9 +103,6 @@ def _get_special_tokens_before_input(self) -> List[int]: dummy_ids = self.tokenizer.build_inputs_with_special_tokens(token_ids_0=[-1]) return dummy_ids[: dummy_ids.index(-1)] - def _add_negative_relations(self, positives: Iterable[DocumentType]) -> Iterable[DocumentType]: - return add_negative_relations(documents=positives) - def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs): self.reset_statistics() result = super().encode(documents=documents, **kwargs) diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 939458bf1..63425c2e0 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -7,7 +7,7 @@ from pytorch_ie.annotations import LabeledSpan from torchmetrics import Metric, MetricCollection -from pie_modules.document.processing.text_pair import add_negative_relations +from pie_modules.document.processing.text_pair import add_negative_coref_relations from pie_modules.document.types import ( BinaryCorefRelation, TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, @@ -90,7 +90,7 @@ def taskmodule(unprepared_taskmodule, positive_documents): def test_construct_negative_documents(positive_documents): assert len(positive_documents) == 2 - docs = list(add_negative_relations(positive_documents)) + docs = list(add_negative_coref_relations(positive_documents)) TEXTS = [ "Entity A works at B.", "And she founded C.", @@ -231,7 +231,7 @@ def test_encode_target(task_encodings_without_target, taskmodule): def test_encode_with_collect_statistics(taskmodule, positive_documents, caplog): - documents_with_negatives = add_negative_relations(positive_documents) + documents_with_negatives = add_negative_coref_relations(positive_documents) caplog.clear() with caplog.at_level(logging.INFO): original_values = taskmodule.collect_statistics From 8093a58c3a2cdef7e9605c6cb5f459055b82e210 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 13 Sep 2024 17:51:38 +0200 Subject: [PATCH 30/49] implement construct_text_pair_coref_documents_from_partitions_via_relations and add dedicated test_text_pair.py --- .../document/processing/text_pair.py | 113 +++++++- tests/document/processing/test_text_pair.py | 267 ++++++++++++++++++ .../test_cross_text_binary_coref.py | 71 ----- 3 files changed, 379 insertions(+), 72 deletions(-) create mode 100644 tests/document/processing/test_text_pair.py diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py index 0275cdf72..e0e4fc446 100644 --- a/src/pie_modules/document/processing/text_pair.py +++ b/src/pie_modules/document/processing/text_pair.py @@ -1,10 +1,121 @@ from collections import defaultdict -from typing import Iterable +from collections.abc import Iterator +from typing import Dict, Iterable, List, Tuple, TypeVar + +from pytorch_ie.annotations import LabeledSpan, Span +from pytorch_ie.documents import ( + TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, +) from pie_modules.document.types import ( BinaryCorefRelation, TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, ) +from pie_modules.utils.span import are_nested + +S = TypeVar("S", bound=Span) +S2 = TypeVar("S2", bound=Span) + + +def _span2partition_mapping(spans: Iterable[S], partitions: Iterable[S2]) -> Dict[S, S2]: + result = {} 
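+    # note: each span is mapped to the first partition that contains it (checked via
+    # are_nested); spans not covered by any partition are left out of the mapping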
+ for span in spans: + for partition in partitions: + if are_nested( + start_end=(span.start, span.end), other_start_end=(partition.start, partition.end) + ): + result[span] = partition + break + return result + + +def _span_copy_shifted(span: S, offset: int) -> S: + return span.copy(start=span.start + offset, end=span.end + offset) + + +def _construct_text_pair_coref_documents_from_partitions_via_relations( + document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, relation_label: str +) -> List[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]: + span2partition = _span2partition_mapping( + spans=document.labeled_spans, partitions=document.labeled_partitions + ) + partition2spans = defaultdict(list) + for span, partition in span2partition.items(): + partition2spans[partition].append(span) + + texts2docs_and_span_mappings: Dict[ + Tuple[str, str], + Tuple[ + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, + Dict[LabeledSpan, LabeledSpan], + Dict[LabeledSpan, LabeledSpan], + ], + ] = dict() + result = [] + for rel in document.binary_relations: + if rel.label != relation_label: + continue + + if rel.head not in span2partition: + raise ValueError(f"head not in any partition: {rel.head}") + head_partition = span2partition[rel.head] + text = document.text[head_partition.start : head_partition.end] + + if rel.tail not in span2partition: + raise ValueError(f"tail not in any partition: {rel.tail}") + tail_partition = span2partition[rel.tail] + text_pair = document.text[tail_partition.start : tail_partition.end] + + if (text, text_pair) in texts2docs_and_span_mappings: + new_doc, head_spans_mapping, tail_spans_mapping = texts2docs_and_span_mappings[ + (text, text_pair) + ] + else: + if document.id is not None: + doc_id = ( + f"{document.id}[{head_partition.start}:{head_partition.end}]" + f"+{document.id}[{tail_partition.start}:{tail_partition.end}]" + ) + else: + doc_id = None + new_doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + id=doc_id, text=text, text_pair=text_pair + ) + + head_spans_mapping = { + span: _span_copy_shifted(span=span, offset=-head_partition.start) + for span in partition2spans[head_partition] + } + new_doc.labeled_spans.extend(head_spans_mapping.values()) + + tail_spans_mapping = { + span: _span_copy_shifted(span=span, offset=-tail_partition.start) + for span in partition2spans[tail_partition] + } + new_doc.labeled_spans_pair.extend(tail_spans_mapping.values()) + + texts2docs_and_span_mappings[(text, text_pair)] = ( + new_doc, + head_spans_mapping, + tail_spans_mapping, + ) + result.append(new_doc) + + coref_rel = BinaryCorefRelation( + head=head_spans_mapping[rel.head], tail=tail_spans_mapping[rel.tail], score=1.0 + ) + new_doc.binary_coref_relations.append(coref_rel) + + return result + + +def construct_text_pair_coref_documents_from_partitions_via_relations( + documents: Iterable[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions], **kwargs +) -> Iterator[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]: + for doc in documents: + yield from _construct_text_pair_coref_documents_from_partitions_via_relations( + document=doc, **kwargs + ) def add_negative_coref_relations( diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py new file mode 100644 index 000000000..5192d87f9 --- /dev/null +++ b/tests/document/processing/test_text_pair.py @@ -0,0 +1,267 @@ +from itertools import chain +from typing import List + +import pytest +from pytorch_ie.annotations import 
BinaryRelation, LabeledSpan +from pytorch_ie.documents import ( + TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, +) + +from pie_modules.document.processing.text_pair import ( + add_negative_coref_relations, + construct_text_pair_coref_documents_from_partitions_via_relations, +) +from pie_modules.document.types import ( + BinaryCorefRelation, + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, +) + +SENTENCES = [ + "Entity A works at B.", + "And she founded C.", + "Bob loves his cat.", + "She sleeps a lot.", +] + + +@pytest.fixture(scope="module") +def text_documents() -> List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]: + doc1 = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions( + id="doc1", text=" ".join(SENTENCES[:2]) + ) + # add sentence partitions + doc1.labeled_partitions.append(LabeledSpan(start=0, end=len(SENTENCES[0]), label="sentence")) + doc1.labeled_partitions.append( + LabeledSpan( + start=len(SENTENCES[0]) + 1, + end=len(SENTENCES[0]) + 1 + len(SENTENCES[1]), + label="sentence", + ) + ) + # add spans + doc1.labeled_spans.append(LabeledSpan(start=0, end=8, label="PERSON")) + doc1.labeled_spans.append(LabeledSpan(start=18, end=19, label="COMPANY")) + doc1_sen2_offset = doc1.labeled_partitions[1].start + doc1.labeled_spans.append( + LabeledSpan(start=4 + doc1_sen2_offset, end=7 + doc1_sen2_offset, label="PERSON") + ) + doc1.labeled_spans.append( + LabeledSpan(start=16 + doc1_sen2_offset, end=17 + doc1_sen2_offset, label="COMPANY") + ) + # add relation + doc1.binary_relations.append( + BinaryRelation( + head=doc1.labeled_spans[0], tail=doc1.labeled_spans[2], label="semantically_same" + ) + ) + + doc2 = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions( + id="doc2", text=" ".join(SENTENCES[2:4]) + ) + # add sentence partitions + doc2.labeled_partitions.append(LabeledSpan(start=0, end=len(SENTENCES[2]), label="sentence")) + doc2.labeled_partitions.append( + LabeledSpan( + start=len(SENTENCES[2]) + 1, + end=len(SENTENCES[2]) + 1 + len(SENTENCES[3]), + label="sentence", + ) + ) + # add spans + doc2.labeled_spans.append(LabeledSpan(start=0, end=3, label="PERSON")) + doc2.labeled_spans.append(LabeledSpan(start=10, end=17, label="ANIMAL")) + doc2_sen2_offset = doc2.labeled_partitions[1].start + doc2.labeled_spans.append( + LabeledSpan(start=0 + doc2_sen2_offset, end=3 + doc2_sen2_offset, label="ANIMAL") + ) + # add relation + doc2.binary_relations.append( + BinaryRelation( + head=doc2.labeled_spans[1], tail=doc2.labeled_spans[2], label="semantically_same" + ) + ) + + return [doc1, doc2] + + +def test_simple_text_documents(text_documents): + assert len(text_documents) == 2 + doc = text_documents[0] + # test serialization + doc.copy() + # test sentences + assert doc.labeled_partitions.resolve() == [ + ("sentence", "Entity A works at B."), + ("sentence", "And she founded C."), + ] + # test spans + assert doc.labeled_spans.resolve() == [ + ("PERSON", "Entity A"), + ("COMPANY", "B"), + ("PERSON", "she"), + ("COMPANY", "C"), + ] + # test relation + assert doc.binary_relations.resolve() == [ + ("semantically_same", (("PERSON", "Entity A"), ("PERSON", "she"))) + ] + + doc = text_documents[1] + # test serialization + doc.copy() + # test sentences + assert doc.labeled_partitions.resolve() == [ + ("sentence", "Bob loves his cat."), + ("sentence", "She sleeps a lot."), + ] + # test spans + assert doc.labeled_spans.resolve() == [ + ("PERSON", "Bob"), + ("ANIMAL", "his cat"), + ("ANIMAL", "She"), + ] + # test relation + assert 
doc.binary_relations.resolve() == [ + ("semantically_same", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) + ] + + +def test_construct_text_pair_coref_documents_from_partitions_via_relations(text_documents): + all_docs = { + doc.id: doc + for doc in construct_text_pair_coref_documents_from_partitions_via_relations( + documents=text_documents, relation_label="semantically_same" + ) + } + assert set(all_docs) == {"doc2[0:18]+doc2[19:36]", "doc1[0:20]+doc1[21:39]"} + + doc = all_docs["doc2[0:18]+doc2[19:36]"] + assert doc.text == "Bob loves his cat." + assert doc.text_pair == "She sleeps a lot." + assert doc.labeled_spans.resolve() == [("PERSON", "Bob"), ("ANIMAL", "his cat")] + assert doc.labeled_spans_pair.resolve() == [("ANIMAL", "She")] + assert doc.binary_coref_relations.resolve() == [ + ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) + ] + + doc = all_docs["doc1[0:20]+doc1[21:39]"] + assert doc.text == "Entity A works at B." + assert doc.text_pair == "And she founded C." + assert doc.labeled_spans.resolve() == [("PERSON", "Entity A"), ("COMPANY", "B")] + assert doc.labeled_spans_pair.resolve() == [("PERSON", "she"), ("COMPANY", "C")] + assert doc.binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))) + ] + + +@pytest.fixture(scope="module") +def positive_documents(): + doc1 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + id="0", text="Entity A works at B.", text_pair="And she founded C." + ) + doc1.labeled_spans.append(LabeledSpan(start=0, end=8, label="PERSON")) + doc1.labeled_spans.append(LabeledSpan(start=18, end=19, label="COMPANY")) + doc1.labeled_spans_pair.append(LabeledSpan(start=4, end=7, label="PERSON")) + doc1.labeled_spans_pair.append(LabeledSpan(start=16, end=17, label="COMPANY")) + doc1.binary_coref_relations.append( + BinaryCorefRelation(head=doc1.labeled_spans[0], tail=doc1.labeled_spans_pair[0]) + ) + + doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + id="0", text="Bob loves his cat.", text_pair="She sleeps a lot." 
+ ) + doc2.labeled_spans.append(LabeledSpan(start=0, end=3, label="PERSON")) + doc2.labeled_spans.append(LabeledSpan(start=10, end=17, label="ANIMAL")) + doc2.labeled_spans_pair.append(LabeledSpan(start=0, end=3, label="ANIMAL")) + doc2.binary_coref_relations.append( + BinaryCorefRelation(head=doc2.labeled_spans[1], tail=doc2.labeled_spans_pair[0]) + ) + + return [doc1, doc2] + + +def test_positive_documents(positive_documents): + assert len(positive_documents) == 2 + doc1, doc2 = positive_documents + assert doc1.labeled_spans.resolve() == [("PERSON", "Entity A"), ("COMPANY", "B")] + assert doc1.labeled_spans_pair.resolve() == [("PERSON", "she"), ("COMPANY", "C")] + assert doc1.binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))) + ] + + assert doc2.labeled_spans.resolve() == [("PERSON", "Bob"), ("ANIMAL", "his cat")] + assert doc2.labeled_spans_pair.resolve() == [("ANIMAL", "She")] + assert doc2.binary_coref_relations.resolve() == [ + ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) + ] + + +def test_construct_negative_documents(positive_documents): + assert len(positive_documents) == 2 + docs = list(add_negative_coref_relations(positive_documents)) + TEXTS = [ + "Entity A works at B.", + "And she founded C.", + "Bob loves his cat.", + "She sleeps a lot.", + ] + assert all(doc.text in TEXTS for doc in docs) + assert all(doc.text_pair in TEXTS for doc in docs) + + all_texts = [(doc.text, doc.text_pair) for doc in docs] + all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs] + all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs] + + all_rels_and_scores = [ + (texts, list(zip(scores, rels_resolved))) + for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved) + ] + + assert all_rels_and_scores == [ + (("And she founded C.", "And she founded C."), []), + ( + ("And she founded C.", "Bob loves his cat."), + [(0.0, ("coref", (("PERSON", "she"), ("PERSON", "Bob"))))], + ), + ( + ("And she founded C.", "Entity A works at B."), + [ + (1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))), + ], + ), + (("And she founded C.", "She sleeps a lot."), []), + ( + ("Bob loves his cat.", "And she founded C."), + [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she"))))], + ), + (("Bob loves his cat.", "Bob loves his cat."), []), + ( + ("Bob loves his cat.", "Entity A works at B."), + [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A"))))], + ), + ( + ("Bob loves his cat.", "She sleeps a lot."), + [(1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))))], + ), + ( + ("Entity A works at B.", "And she founded C."), + [ + (1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))), + (0.0, ("coref", (("COMPANY", "B"), ("COMPANY", "C")))), + ], + ), + ( + ("Entity A works at B.", "Bob loves his cat."), + [(0.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "Bob"))))], + ), + (("Entity A works at B.", "Entity A works at B."), []), + (("Entity A works at B.", "She sleeps a lot."), []), + (("She sleeps a lot.", "And she founded C."), []), + ( + ("She sleeps a lot.", "Bob loves his cat."), + [(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))], + ), + (("She sleeps a lot.", "Entity A works at B."), []), + (("She sleeps a lot.", "She sleeps a lot."), []), + ] diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 63425c2e0..42d62280a 
100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -88,77 +88,6 @@ def taskmodule(unprepared_taskmodule, positive_documents): return unprepared_taskmodule -def test_construct_negative_documents(positive_documents): - assert len(positive_documents) == 2 - docs = list(add_negative_coref_relations(positive_documents)) - TEXTS = [ - "Entity A works at B.", - "And she founded C.", - "Bob loves his cat.", - "She sleeps a lot.", - ] - assert all(doc.text in TEXTS for doc in docs) - assert all(doc.text_pair in TEXTS for doc in docs) - - all_texts = [(doc.text, doc.text_pair) for doc in docs] - all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs] - all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs] - - all_rels_and_scores = [ - (texts, list(zip(scores, rels_resolved))) - for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved) - ] - - assert all_rels_and_scores == [ - (("And she founded C.", "And she founded C."), []), - ( - ("And she founded C.", "Bob loves his cat."), - [(0.0, ("coref", (("PERSON", "she"), ("PERSON", "Bob"))))], - ), - ( - ("And she founded C.", "Entity A works at B."), - [ - (1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A")))), - (0.0, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))), - ], - ), - (("And she founded C.", "She sleeps a lot."), []), - ( - ("Bob loves his cat.", "And she founded C."), - [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she"))))], - ), - (("Bob loves his cat.", "Bob loves his cat."), []), - ( - ("Bob loves his cat.", "Entity A works at B."), - [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A"))))], - ), - ( - ("Bob loves his cat.", "She sleeps a lot."), - [(1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))))], - ), - ( - ("Entity A works at B.", "And she founded C."), - [ - (1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))), - (0.0, ("coref", (("COMPANY", "B"), ("COMPANY", "C")))), - ], - ), - ( - ("Entity A works at B.", "Bob loves his cat."), - [(0.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "Bob"))))], - ), - (("Entity A works at B.", "Entity A works at B."), []), - (("Entity A works at B.", "She sleeps a lot."), []), - (("She sleeps a lot.", "And she founded C."), []), - ( - ("She sleeps a lot.", "Bob loves his cat."), - [(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))], - ), - (("She sleeps a lot.", "Entity A works at B."), []), - (("She sleeps a lot.", "She sleeps a lot."), []), - ] - - @pytest.fixture(scope="module") def documents_with_negatives(taskmodule, positive_documents): file_name = ( From c0c28750243f8f4cd1e57d59781c484ae41ea6b6 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Fri, 13 Sep 2024 18:39:37 +0200 Subject: [PATCH 31/49] add tqdm to add_negative_coref_relations --- src/pie_modules/document/processing/text_pair.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py index e0e4fc446..69fcbc0d8 100644 --- a/src/pie_modules/document/processing/text_pair.py +++ b/src/pie_modules/document/processing/text_pair.py @@ -6,6 +6,7 @@ from pytorch_ie.documents import ( TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, ) +from tqdm import tqdm from pie_modules.document.types import ( BinaryCorefRelation, @@ -134,7 +135,7 @@ def add_negative_coref_relations( positive_tuples[(doc.text_pair, 
doc.text)].add((coref.tail.copy(), coref.head.copy())) new_docs = [] - for text in sorted(text2spans): + for text in tqdm(sorted(text2spans)): for text_pair in sorted(text2spans): current_positives = positive_tuples.get((text, text_pair), set()) new_doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( From 81ea67c05289396249b9aaf5820ebc2c84de5524 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Sun, 15 Sep 2024 16:56:35 +0200 Subject: [PATCH 32/49] move document and annotation types to documents and annotations modules, respectively --- src/pie_modules/annotations.py | 5 ++ .../document/processing/text_pair.py | 2 +- src/pie_modules/document/types.py | 75 ------------------- src/pie_modules/documents.py | 61 +++++++++++++++ .../taskmodules/cross_text_binary_coref.py | 2 +- tests/document/processing/test_text_pair.py | 2 +- .../test_cross_text_binary_coref.py | 2 +- 7 files changed, 70 insertions(+), 79 deletions(-) delete mode 100644 src/pie_modules/document/types.py diff --git a/src/pie_modules/annotations.py b/src/pie_modules/annotations.py index 09a31bbee..208c2d1b9 100644 --- a/src/pie_modules/annotations.py +++ b/src/pie_modules/annotations.py @@ -63,3 +63,8 @@ class GenerativeAnswer(AnnotationWithText): score: Optional[float] = dataclasses.field(default=None, compare=False) question: Optional[Question] = None + + +@dataclasses.dataclass(eq=True, frozen=True) +class BinaryCorefRelation(BinaryRelation): + label: str = "coref" diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py index 69fcbc0d8..a5afc4c34 100644 --- a/src/pie_modules/document/processing/text_pair.py +++ b/src/pie_modules/document/processing/text_pair.py @@ -8,7 +8,7 @@ ) from tqdm import tqdm -from pie_modules.document.types import ( +from pie_modules.documents import ( BinaryCorefRelation, TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, ) diff --git a/src/pie_modules/document/types.py b/src/pie_modules/document/types.py deleted file mode 100644 index b327a7903..000000000 --- a/src/pie_modules/document/types.py +++ /dev/null @@ -1,75 +0,0 @@ -import dataclasses - -from pytorch_ie import AnnotationLayer, annotation_field -from pytorch_ie.annotations import BinaryRelation, LabeledSpan -from pytorch_ie.documents import ( - TextBasedDocument, - TextDocumentWithLabeledPartitions, - TextDocumentWithLabeledSpans, - TextDocumentWithLabeledSpansAndLabeledPartitions, -) - - -@dataclasses.dataclass -class WithTextPair: - text_pair: str - - -@dataclasses.dataclass -class WithLabeledSpansPair(WithTextPair): - labeled_spans_pair: AnnotationLayer[LabeledSpan] = annotation_field(target="text_pair") - - -@dataclasses.dataclass -class WithLabeledPartitionsPair(WithTextPair): - labeled_partitions_pair: AnnotationLayer[LabeledSpan] = annotation_field(target="text_pair") - - -@dataclasses.dataclass -class TextPairBasedDocument(TextBasedDocument, WithTextPair): - pass - - -@dataclasses.dataclass -class TextPairDocumentWithLabeledPartitions( - WithLabeledPartitionsPair, TextPairBasedDocument, TextDocumentWithLabeledPartitions -): - pass - - -@dataclasses.dataclass -class TextPairDocumentWithLabeledSpans( - WithLabeledSpansPair, TextPairBasedDocument, TextDocumentWithLabeledSpans -): - pass - - -@dataclasses.dataclass -class TextPairDocumentWithLabeledSpansAndLabeledPartitions( - TextPairDocumentWithLabeledPartitions, - TextPairDocumentWithLabeledSpans, - TextDocumentWithLabeledSpansAndLabeledPartitions, -): - pass - - -@dataclasses.dataclass(eq=True, frozen=True) 
-class BinaryCorefRelation(BinaryRelation): - label: str = "coref" - - -@dataclasses.dataclass -class TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( - TextPairDocumentWithLabeledSpans, TextDocumentWithLabeledSpans -): - binary_coref_relations: AnnotationLayer[BinaryCorefRelation] = annotation_field( - targets=["labeled_spans", "labeled_spans_pair"] - ) - - -@dataclasses.dataclass -class TextPairDocumentWithLabeledSpansSimilarityRelationsAndLabeledPartitions( - TextPairDocumentWithLabeledSpansAndLabeledPartitions, - TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, -): - pass diff --git a/src/pie_modules/documents.py b/src/pie_modules/documents.py index 4dae18b78..3de89e850 100644 --- a/src/pie_modules/documents.py +++ b/src/pie_modules/documents.py @@ -27,6 +27,7 @@ from pie_modules.annotations import ( AbstractiveSummary, + BinaryCorefRelation, BinaryRelation, ExtractiveAnswer, GenerativeAnswer, @@ -151,3 +152,63 @@ class TokenDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions( TokenDocumentWithLabeledMultiSpansAndBinaryRelations, ): pass + + +@dataclasses.dataclass +class WithTextPair: + text_pair: str + + +@dataclasses.dataclass +class WithLabeledSpansPair(WithTextPair): + labeled_spans_pair: AnnotationLayer[LabeledSpan] = annotation_field(target="text_pair") + + +@dataclasses.dataclass +class WithLabeledPartitionsPair(WithTextPair): + labeled_partitions_pair: AnnotationLayer[LabeledSpan] = annotation_field(target="text_pair") + + +@dataclasses.dataclass +class TextPairBasedDocument(TextBasedDocument, WithTextPair): + pass + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledPartitions( + WithLabeledPartitionsPair, TextPairBasedDocument, TextDocumentWithLabeledPartitions +): + pass + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledSpans( + WithLabeledSpansPair, TextPairBasedDocument, TextDocumentWithLabeledSpans +): + pass + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledSpansAndLabeledPartitions( + TextPairDocumentWithLabeledPartitions, + TextPairDocumentWithLabeledSpans, + TextDocumentWithLabeledSpansAndLabeledPartitions, +): + pass + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + TextPairDocumentWithLabeledSpans, TextDocumentWithLabeledSpans +): + binary_coref_relations: AnnotationLayer[BinaryCorefRelation] = annotation_field( + targets=["labeled_spans", "labeled_spans_pair"] + ) + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledSpansSimilarityRelationsAndLabeledPartitions( + TextPairDocumentWithLabeledSpansAndLabeledPartitions, + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, +): + pass diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index 5a8bfba4f..9f08556d1 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -23,7 +23,7 @@ from transformers import AutoTokenizer, BatchEncoding from typing_extensions import TypeAlias -from pie_modules.document.types import ( +from pie_modules.documents import ( TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, ) from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py index 5192d87f9..9d4aa1a20 100644 --- a/tests/document/processing/test_text_pair.py +++ b/tests/document/processing/test_text_pair.py @@ -11,7 +11,7 @@ 
add_negative_coref_relations, construct_text_pair_coref_documents_from_partitions_via_relations, ) -from pie_modules.document.types import ( +from pie_modules.documents import ( BinaryCorefRelation, TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, ) diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index 42d62280a..5555a0a3c 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -8,7 +8,7 @@ from torchmetrics import Metric, MetricCollection from pie_modules.document.processing.text_pair import add_negative_coref_relations -from pie_modules.document.types import ( +from pie_modules.documents import ( BinaryCorefRelation, TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, ) From b1ceac26a57d81da931734860e073d55548a8197 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Sun, 15 Sep 2024 17:32:24 +0200 Subject: [PATCH 33/49] fix tokenization in encode_input --- src/pie_modules/taskmodules/cross_text_binary_coref.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index 9f08556d1..25e87b411 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -149,9 +149,7 @@ def encode_input( self.collect_all_relations(kind="available", relations=document.binary_coref_relations) tokenizer_kwargs = dict( padding=False, - truncation=True, - max_length=self.tokenizer.model_max_length, - return_offsets_mapping=False, + truncation=False, add_special_tokens=False, ) encoding = self.tokenizer(text=document.text, **tokenizer_kwargs) From b5f0c8bd58df927a04530e5bf07d8fd4f6f93171 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Sun, 15 Sep 2024 18:43:30 +0200 Subject: [PATCH 34/49] outsource get_aligned_token_span() and SpanNotAlignedWithTokenException to utils.tokenization module --- .../taskmodules/cross_text_binary_coref.py | 35 ++++++------ src/pie_modules/utils/tokenization.py | 35 ++++++++++++ tests/utils/test_tokenization.py | 55 +++++++++++++++++++ 3 files changed, 108 insertions(+), 17 deletions(-) create mode 100644 src/pie_modules/utils/tokenization.py create mode 100644 tests/utils/test_tokenization.py diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index 25e87b411..3f039cec3 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -10,6 +10,7 @@ Sequence, Tuple, TypedDict, + TypeVar, Union, ) @@ -29,6 +30,10 @@ from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction from pie_modules.utils import list_of_dicts2dict_of_lists +from pie_modules.utils.tokenization import ( + SpanNotAlignedWithTokenException, + get_aligned_token_span, +) logger = logging.getLogger(__name__) @@ -63,11 +68,6 @@ class TaskOutputType(TypedDict, total=False): ] -class SpanNotAlignedWithTokenException(Exception): - def __init__(self, span): - self.span = span - - class SpanDoesNotFitIntoAvailableWindow(Exception): def __init__(self, span): self.span = span @@ -77,6 +77,13 @@ def _get_labels(model_output: ModelTargetType) -> torch.Tensor: return model_output["labels"] +S = TypeVar("S", bound=Span) + + +def shift_span(span: S, offset: int) -> S: + 
return span.copy(start=span.start + offset, end=span.end + offset) + + @TaskModule.register() class CrossTextBinaryCorefTaskModule(RelationStatisticsMixin, TaskModuleType): """This taskmodule processes documents of type @@ -114,32 +121,26 @@ def truncate_encoding_around_span( ) -> Tuple[Dict[str, List[int]], Span]: input_ids = copy.deepcopy(encoding["input_ids"]) - token_start = encoding.char_to_token(char_span.start) - token_end_before = encoding.char_to_token(char_span.end - 1) - if token_start is None or token_end_before is None: - raise SpanNotAlignedWithTokenException(span=char_span) - token_end = token_end_before + 1 + token_span = get_aligned_token_span(encoding=encoding, char_span=char_span) # truncate input_ids and shift token_start and token_end if len(input_ids) > self.available_window: window_slice = get_window_around_slice( - slice=[token_start, token_end], + slice=(token_span.start, token_span.end), max_window_size=self.available_window, available_input_length=len(input_ids), ) if window_slice is None: - raise SpanDoesNotFitIntoAvailableWindow(span=(token_start, token_end)) + raise SpanDoesNotFitIntoAvailableWindow(span=token_span) window_start, window_end = window_slice input_ids = input_ids[window_start:window_end] - token_start -= window_start - token_end -= window_start + token_span = shift_span(token_span, offset=-window_start) truncated_encoding = self.tokenizer.prepare_for_model(ids=input_ids) # shift indices because we added special tokens to the input_ids - token_start += self.num_special_tokens_before - token_end += self.num_special_tokens_before + token_span = shift_span(token_span, offset=self.num_special_tokens_before) - return truncated_encoding, Span(start=token_start, end=token_end) + return truncated_encoding, token_span def encode_input( self, diff --git a/src/pie_modules/utils/tokenization.py b/src/pie_modules/utils/tokenization.py new file mode 100644 index 000000000..addc2fb03 --- /dev/null +++ b/src/pie_modules/utils/tokenization.py @@ -0,0 +1,35 @@ +from typing import TypeVar + +from pytorch_ie.annotations import Span +from transformers import BatchEncoding + +S = TypeVar("S", bound=Span) + + +class SpanNotAlignedWithTokenException(Exception): + def __init__(self, span): + self.span = span + + +def get_aligned_token_span(encoding: BatchEncoding, char_span: S) -> S: + # find the start + token_start = None + token_end_before = None + char_start = None + for idx in range(char_span.start, char_span.end): + token_start = encoding.char_to_token(idx) + if token_start is not None: + char_start = idx + break + + if char_start is None: + raise SpanNotAlignedWithTokenException(span=char_span) + for idx in range(char_span.end - 1, char_start - 1, -1): + token_end_before = encoding.char_to_token(idx) + if token_end_before is not None: + break + + if token_start is None or token_end_before is None: + raise SpanNotAlignedWithTokenException(span=char_span) + + return char_span.copy(start=token_start, end=token_end_before + 1) diff --git a/tests/utils/test_tokenization.py b/tests/utils/test_tokenization.py new file mode 100644 index 000000000..0b51ce62d --- /dev/null +++ b/tests/utils/test_tokenization.py @@ -0,0 +1,55 @@ +import pytest +from pytorch_ie.annotations import Span +from transformers import AutoTokenizer + +from pie_modules.utils.tokenization import ( + SpanNotAlignedWithTokenException, + get_aligned_token_span, +) + + +def test_get_aligned_token_span(): + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + + text = "Hello, world!" 
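+    # note: get_aligned_token_span relies on encoding.char_to_token(), which is only
+    # available for fast tokenizers; AutoTokenizer loads a fast one for bert-base-cased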
+ encoding = tokenizer(text) + tokens = tokenizer.convert_ids_to_tokens(encoding.input_ids) + assert tokens == ["[CLS]", "Hello", ",", "world", "!", "[SEP]"] + + # already aligned + char_span = Span(0, 5) + assert text[char_span.start : char_span.end] == "Hello" + token_span = get_aligned_token_span(encoding=encoding, char_span=char_span) + assert tokens[token_span.start : token_span.end] == ["Hello"] + + # end not aligned + char_span = Span(5, 7) + assert text[char_span.start : char_span.end] == ", " + token_span = get_aligned_token_span(encoding=encoding, char_span=char_span) + assert tokens[token_span.start : token_span.end] == [","] + + # start not aligned + char_span = Span(6, 12) + assert text[char_span.start : char_span.end] == " world" + token_span = get_aligned_token_span(encoding=encoding, char_span=char_span) + assert tokens[token_span.start : token_span.end] == ["world"] + + # start not aligned, end inside token + char_span = Span(6, 8) + assert text[char_span.start : char_span.end] == " w" + token_span = get_aligned_token_span(encoding=encoding, char_span=char_span) + assert tokens[token_span.start : token_span.end] == ["world"] + + # empty char span + char_span = Span(2, 2) + assert text[char_span.start : char_span.end] == "" + with pytest.raises(SpanNotAlignedWithTokenException) as e: + get_aligned_token_span(encoding=encoding, char_span=char_span) + assert e.value.span == char_span + + # empty token span + char_span = Span(6, 7) + assert text[char_span.start : char_span.end] == " " + with pytest.raises(SpanNotAlignedWithTokenException) as e: + get_aligned_token_span(encoding=encoding, char_span=char_span) + assert e.value.span == char_span From 7487fd730c4b15d8670d59921a62c636ea05c812 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Sun, 15 Sep 2024 19:41:53 +0200 Subject: [PATCH 35/49] implement construct_text_document_from_text_pair_coref_document() --- .../document/processing/text_pair.py | 46 ++++++++++++++ tests/document/processing/test_text_pair.py | 62 ++++++++++++++++--- 2 files changed, 101 insertions(+), 7 deletions(-) diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py index a5afc4c34..a9902293f 100644 --- a/src/pie_modules/document/processing/text_pair.py +++ b/src/pie_modules/document/processing/text_pair.py @@ -1,9 +1,12 @@ +import copy from collections import defaultdict from collections.abc import Iterator +from itertools import chain from typing import Dict, Iterable, List, Tuple, TypeVar from pytorch_ie.annotations import LabeledSpan, Span from pytorch_ie.documents import ( + TextDocumentWithLabeledSpansAndBinaryRelations, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, ) from tqdm import tqdm @@ -119,6 +122,49 @@ def construct_text_pair_coref_documents_from_partitions_via_relations( ) +def shift_span(span: S, offset: int) -> S: + return span.copy(start=span.start + offset, end=span.end + offset) + + +def construct_text_document_from_text_pair_coref_document( + document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, glue_text: str +) -> TextDocumentWithLabeledSpansAndBinaryRelations: + if document.text == document.text_pair: + new_doc = TextDocumentWithLabeledSpansAndBinaryRelations( + id=document.id, metadata=copy.deepcopy(document.metadata), text=document.text + ) + old2new_spans: Dict[LabeledSpan, LabeledSpan] = {} + new2new_spans: Dict[LabeledSpan, LabeledSpan] = {} + for old_span in chain(document.labeled_spans, document.labeled_spans_pair): + new_span = 
old_span.copy()
+            # when detaching / copying a span, it may equal a previous span from the other layer
+            new_span = new2new_spans.get(new_span, new_span)
+            new2new_spans[new_span] = new_span
+            old2new_spans[old_span] = new_span
+    else:
+        new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
+            text=document.text + glue_text + document.text_pair,
+            id=document.id,
+            metadata=copy.deepcopy(document.metadata),
+        )
+        old2new_spans = {}
+        old2new_spans.update({span: span.copy() for span in document.labeled_spans})
+        offset = len(document.text) + len(glue_text)
+        old2new_spans.update(
+            {span: shift_span(span.copy(), offset) for span in document.labeled_spans_pair}
+        )
+
+    # sort to make order deterministic
+    new_doc.labeled_spans.extend(
+        sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label))
+    )
+    for old_rel in document.binary_coref_relations:
+        new_rel = old_rel.copy(head=old2new_spans[old_rel.head], tail=old2new_spans[old_rel.tail])
+        new_doc.binary_relations.append(new_rel)
+
+    return new_doc
+
+
 def add_negative_coref_relations(
     documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations], **kwargs
 ) -> Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]:
diff --git a/tests/document/processing/test_text_pair.py
index 9d4aa1a20..c4ade49ee 100644
--- a/tests/document/processing/test_text_pair.py
+++ b/tests/document/processing/test_text_pair.py
@@ -9,6 +9,7 @@
 from pie_modules.document.processing.text_pair import (
     add_negative_coref_relations,
+    construct_text_document_from_text_pair_coref_document,
     construct_text_pair_coref_documents_from_partitions_via_relations,
 )
 from pie_modules.documents import (
@@ -196,21 +197,31 @@ def test_positive_documents(positive_documents):
     ]
 
 
-def test_construct_negative_documents(positive_documents):
-    assert len(positive_documents) == 2
+@pytest.fixture(scope="module")
+def positive_and_negative_documents(positive_documents):
     docs = list(add_negative_coref_relations(positive_documents))
+    return docs
+
+
+def test_construct_negative_documents(positive_and_negative_documents):
+    assert len(positive_and_negative_documents) == 16
     TEXTS = [
         "Entity A works at B.",
         "And she founded C.",
         "Bob loves his cat.",
         "She sleeps a lot.",
     ]
-    assert all(doc.text in TEXTS for doc in docs)
-    assert all(doc.text_pair in TEXTS for doc in docs)
+    assert all(doc.text in TEXTS for doc in positive_and_negative_documents)
+    assert all(doc.text_pair in TEXTS for doc in positive_and_negative_documents)
 
-    all_texts = [(doc.text, doc.text_pair) for doc in docs]
-    all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs]
-    all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs]
+    all_texts = [(doc.text, doc.text_pair) for doc in positive_and_negative_documents]
+    all_scores = [
+        [coref_rel.score for coref_rel in doc.binary_coref_relations]
+        for doc in positive_and_negative_documents
+    ]
+    all_rels_resolved = [
+        doc.binary_coref_relations.resolve() for doc in positive_and_negative_documents
+    ]
 
     all_rels_and_scores = [
         (texts, list(zip(scores, rels_resolved)))
@@ -265,3 +276,40 @@ def test_construct_negative_documents(positive_and_negative_documents):
     (("She sleeps a lot.", "Entity A works at B."), []),
     (("She sleeps a lot.", "She sleeps a lot."), []),
 ]
+
+
+def test_construct_text_document_from_text_pair_coref_document(positive_and_negative_documents):
+    glue_text = ""
+    docs = [
+        construct_text_document_from_text_pair_coref_document(doc, 
glue_text=glue_text) + for doc in positive_and_negative_documents + ] + assert len(docs) == 16 + doc = docs[0] + assert doc.text == "And she founded C." + assert doc.labeled_spans.resolve() == [("PERSON", "she"), ("COMPANY", "C")] + assert doc.binary_relations.resolve() == [] + assert [rel.score for rel in doc.binary_relations] == [] + + doc = docs[1] + assert doc.text == "And she founded C.Bob loves his cat." + assert doc.labeled_spans.resolve() == [ + ("PERSON", "she"), + ("COMPANY", "C"), + ("PERSON", "Bob"), + ("ANIMAL", "his cat"), + ] + assert doc.binary_relations.resolve() == [("coref", (("PERSON", "she"), ("PERSON", "Bob")))] + assert [rel.score for rel in doc.binary_relations] == [0.0] + + doc = docs[7] + assert doc.text == "Bob loves his cat.She sleeps a lot." + assert doc.labeled_spans.resolve() == [ + ("PERSON", "Bob"), + ("ANIMAL", "his cat"), + ("ANIMAL", "She"), + ] + assert doc.binary_relations.resolve() == [ + ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) + ] + assert [rel.score for rel in doc.binary_relations] == [1.0] From 961a2a074d28a47af24202b05d9e5e187d643547 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Sun, 15 Sep 2024 19:51:55 +0200 Subject: [PATCH 36/49] add parameter relation_label_mapping to construct_text_document_from_text_pair_coref_document() --- src/pie_modules/document/processing/text_pair.py | 13 ++++++++++--- tests/document/processing/test_text_pair.py | 10 +++++++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py index a9902293f..4387177e6 100644 --- a/src/pie_modules/document/processing/text_pair.py +++ b/src/pie_modules/document/processing/text_pair.py @@ -2,7 +2,7 @@ from collections import defaultdict from collections.abc import Iterator from itertools import chain -from typing import Dict, Iterable, List, Tuple, TypeVar +from typing import Dict, Iterable, List, Optional, Tuple, TypeVar from pytorch_ie.annotations import LabeledSpan, Span from pytorch_ie.documents import ( @@ -127,7 +127,9 @@ def shift_span(span: S, offset: int) -> S: def construct_text_document_from_text_pair_coref_document( - document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, glue_text: str + document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, + glue_text: str, + relation_label_mapping: Optional[Dict[str, str]] = None, ) -> TextDocumentWithLabeledSpansAndBinaryRelations: if document.text == document.text_pair: new_doc = TextDocumentWithLabeledSpansAndBinaryRelations( @@ -159,7 +161,12 @@ def construct_text_document_from_text_pair_coref_document( sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label)) ) for old_rel in document.binary_coref_relations: - new_rel = old_rel.copy(head=old2new_spans[old_rel.head], tail=old2new_spans[old_rel.tail]) + label = old_rel.label + if relation_label_mapping is not None: + label = relation_label_mapping.get(label, label) + new_rel = old_rel.copy( + head=old2new_spans[old_rel.head], tail=old2new_spans[old_rel.tail], label=label + ) new_doc.binary_relations.append(new_rel) return new_doc diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py index c4ade49ee..a9c589545 100644 --- a/tests/document/processing/test_text_pair.py +++ b/tests/document/processing/test_text_pair.py @@ -281,7 +281,9 @@ def test_construct_negative_documents(positive_and_negative_documents): def 
test_construct_text_document_from_text_pair_coref_document(positive_and_negative_documents): glue_text = "" docs = [ - construct_text_document_from_text_pair_coref_document(doc, glue_text=glue_text) + construct_text_document_from_text_pair_coref_document( + doc, glue_text=glue_text, relation_label_mapping={"coref": "semantically_same"} + ) for doc in positive_and_negative_documents ] assert len(docs) == 16 @@ -299,7 +301,9 @@ def test_construct_text_document_from_text_pair_coref_document(positive_and_nega ("PERSON", "Bob"), ("ANIMAL", "his cat"), ] - assert doc.binary_relations.resolve() == [("coref", (("PERSON", "she"), ("PERSON", "Bob")))] + assert doc.binary_relations.resolve() == [ + ("semantically_same", (("PERSON", "she"), ("PERSON", "Bob"))) + ] assert [rel.score for rel in doc.binary_relations] == [0.0] doc = docs[7] @@ -310,6 +314,6 @@ def test_construct_text_document_from_text_pair_coref_document(positive_and_nega ("ANIMAL", "She"), ] assert doc.binary_relations.resolve() == [ - ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) + ("semantically_same", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) ] assert [rel.score for rel in doc.binary_relations] == [1.0] From 1e3b5d431b121d62be536f4f4030495e1d24d435 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Sun, 15 Sep 2024 20:02:29 +0200 Subject: [PATCH 37/49] add parameter no_relation_label to construct_text_document_from_text_pair_coref_document() --- src/pie_modules/document/processing/text_pair.py | 8 ++++++-- tests/document/processing/test_text_pair.py | 9 ++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py index 4387177e6..cf3885b21 100644 --- a/src/pie_modules/document/processing/text_pair.py +++ b/src/pie_modules/document/processing/text_pair.py @@ -129,6 +129,7 @@ def shift_span(span: S, offset: int) -> S: def construct_text_document_from_text_pair_coref_document( document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, glue_text: str, + no_relation_label: str, relation_label_mapping: Optional[Dict[str, str]] = None, ) -> TextDocumentWithLabeledSpansAndBinaryRelations: if document.text == document.text_pair: @@ -161,11 +162,14 @@ def construct_text_document_from_text_pair_coref_document( sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label)) ) for old_rel in document.binary_coref_relations: - label = old_rel.label + label = old_rel.label if old_rel.score > 0.0 else no_relation_label if relation_label_mapping is not None: label = relation_label_mapping.get(label, label) new_rel = old_rel.copy( - head=old2new_spans[old_rel.head], tail=old2new_spans[old_rel.tail], label=label + head=old2new_spans[old_rel.head], + tail=old2new_spans[old_rel.tail], + label=label, + score=1.0, ) new_doc.binary_relations.append(new_rel) diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py index a9c589545..cd016a117 100644 --- a/tests/document/processing/test_text_pair.py +++ b/tests/document/processing/test_text_pair.py @@ -282,7 +282,10 @@ def test_construct_text_document_from_text_pair_coref_document(positive_and_nega glue_text = "" docs = [ construct_text_document_from_text_pair_coref_document( - doc, glue_text=glue_text, relation_label_mapping={"coref": "semantically_same"} + doc, + glue_text=glue_text, + no_relation_label="no_relation", + relation_label_mapping={"coref": "semantically_same"}, ) for doc in positive_and_negative_documents ] @@ -302,9 
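A minimal usage sketch of the `relation_label_mapping` parameter introduced above, at the state of this patch (i.e. before `no_relation_label` is added next). The document content is invented for illustration; the import paths follow the test module and `pie_modules.document.types`:

```python
from pytorch_ie.annotations import LabeledSpan

from pie_modules.document.processing.text_pair import (
    construct_text_document_from_text_pair_coref_document,
)
from pie_modules.document.types import (
    BinaryCorefRelation,
    TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
)

# build a tiny text pair document with one cross-text coreference
doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
    text="Entity A works at B.", text_pair="And she founded C."
)
doc.labeled_spans.append(LabeledSpan(start=0, end=8, label="PERSON"))
doc.labeled_spans_pair.append(LabeledSpan(start=4, end=7, label="PERSON"))
doc.binary_coref_relations.append(
    BinaryCorefRelation(head=doc.labeled_spans[0], tail=doc.labeled_spans_pair[0])
)

merged = construct_text_document_from_text_pair_coref_document(
    doc, glue_text=" ", relation_label_mapping={"coref": "semantically_same"}
)
# both texts are concatenated, spans of the second text are shifted,
# and the relation label is rewritten via the mapping
assert merged.binary_relations[0].label == "semantically_same"
```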
From 1e3b5d431b121d62be536f4f4030495e1d24d435 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Sun, 15 Sep 2024 20:02:29 +0200
Subject: [PATCH 37/49] add parameter no_relation_label to construct_text_document_from_text_pair_coref_document()

---
 src/pie_modules/document/processing/text_pair.py | 8 ++++++--
 tests/document/processing/test_text_pair.py | 9 ++++++---
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py
index 4387177e6..cf3885b21 100644
--- a/src/pie_modules/document/processing/text_pair.py
+++ b/src/pie_modules/document/processing/text_pair.py
@@ -129,6 +129,7 @@ def shift_span(span: S, offset: int) -> S:
 def construct_text_document_from_text_pair_coref_document(
     document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
     glue_text: str,
+    no_relation_label: str,
     relation_label_mapping: Optional[Dict[str, str]] = None,
 ) -> TextDocumentWithLabeledSpansAndBinaryRelations:
     if document.text == document.text_pair:
@@ -161,11 +162,14 @@ def construct_text_document_from_text_pair_coref_document(
         sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label))
     )
     for old_rel in document.binary_coref_relations:
-        label = old_rel.label
+        label = old_rel.label if old_rel.score > 0.0 else no_relation_label
         if relation_label_mapping is not None:
             label = relation_label_mapping.get(label, label)
         new_rel = old_rel.copy(
-            head=old2new_spans[old_rel.head], tail=old2new_spans[old_rel.tail], label=label
+            head=old2new_spans[old_rel.head],
+            tail=old2new_spans[old_rel.tail],
+            label=label,
+            score=1.0,
         )
         new_doc.binary_relations.append(new_rel)
 
diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py
index a9c589545..cd016a117 100644
--- a/tests/document/processing/test_text_pair.py
+++ b/tests/document/processing/test_text_pair.py
@@ -282,7 +282,10 @@ def test_construct_text_document_from_text_pair_coref_document(positive_and_nega
     glue_text = ""
     docs = [
         construct_text_document_from_text_pair_coref_document(
-            doc, glue_text=glue_text, relation_label_mapping={"coref": "semantically_same"}
+            doc,
+            glue_text=glue_text,
+            no_relation_label="no_relation",
+            relation_label_mapping={"coref": "semantically_same"},
         )
         for doc in positive_and_negative_documents
     ]
     assert len(docs) == 16
@@ -302,9 +305,9 @@ def test_construct_text_document_from_text_pair_coref_document(positive_and_nega
         ("ANIMAL", "his cat"),
     ]
     assert doc.binary_relations.resolve() == [
-        ("semantically_same", (("PERSON", "she"), ("PERSON", "Bob")))
+        ("no_relation", (("PERSON", "she"), ("PERSON", "Bob")))
     ]
-    assert [rel.score for rel in doc.binary_relations] == [0.0]
+    assert [rel.score for rel in doc.binary_relations] == [1.0]
 
     doc = docs[7]
     assert doc.text == "Bob loves his cat.She sleeps a lot."

From b8d99394138f699fdeaba66923384e9133cb9df4 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Sun, 15 Sep 2024 21:52:48 +0200
Subject: [PATCH 38/49] prepare downsampling of negatives

---
 .../document/processing/text_pair.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py
index cf3885b21..6b85e51a8 100644
--- a/src/pie_modules/document/processing/text_pair.py
+++ b/src/pie_modules/document/processing/text_pair.py
@@ -192,6 +192,9 @@ def add_negative_coref_relations(
             positive_tuples[(doc.text_pair, doc.text)].add((coref.tail.copy(), coref.head.copy()))
 
     new_docs = []
+    new_rels2new_docs = {}
+    positive_rels = []
+    negative_rels = []
     for text in tqdm(sorted(text2spans)):
         for text_pair in sorted(text2spans):
             current_positives = positive_tuples.get((text, text_pair), set())
@@ -211,7 +214,21 @@ def add_negative_coref_relations(
                     continue
                 score = 1.0 if (s.copy(), s_p.copy()) in current_positives else 0.0
                 new_coref_rel = BinaryCorefRelation(head=s, tail=s_p, score=score)
-                new_doc.binary_coref_relations.append(new_coref_rel)
+                # new_doc.binary_coref_relations.append(new_coref_rel)
+                new_rels2new_docs[new_coref_rel] = new_doc
+                if score > 0.0:
+                    positive_rels.append(new_coref_rel)
+                else:
+                    negative_rels.append(new_coref_rel)
         new_docs.append(new_doc)
 
+    for rel in positive_rels:
+        new_rels2new_docs[rel].binary_coref_relations.append(rel)
+
+    # TODO: implement down sampling
+    for rel in negative_rels:
+        new_rels2new_docs[rel].binary_coref_relations.append(rel)
+
+    # docs_with_rels = [doc for doc in new_docs if len(doc.binary_coref_relations) > 0]
+    # return docs_with_rels
     return new_docs

From 1d73f8a42a93fc0ff6f8a2ca192f99eb09c41bea Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Sun, 15 Sep 2024 22:04:03 +0200
Subject: [PATCH 39/49] add_negative_coref_relations does not return docs without relations

---
 .../document/processing/text_pair.py | 7 +++----
 tests/document/processing/test_text_pair.py | 20 +++----------------
 2 files changed, 6 insertions(+), 21 deletions(-)

diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py
index 6b85e51a8..e052342cc 100644
--- a/src/pie_modules/document/processing/text_pair.py
+++ b/src/pie_modules/document/processing/text_pair.py
@@ -225,10 +225,9 @@ def add_negative_coref_relations(
     for rel in positive_rels:
         new_rels2new_docs[rel].binary_coref_relations.append(rel)
 
-    # TODO: implement down sampling
+    # TODO: implement downsampling
     for rel in negative_rels:
         new_rels2new_docs[rel].binary_coref_relations.append(rel)
 
-    # docs_with_rels = [doc for doc in new_docs if len(doc.binary_coref_relations) > 0]
-    # return docs_with_rels
-    return new_docs
+    docs_with_rels = [doc for doc in new_docs if len(doc.binary_coref_relations) > 0]
+    return docs_with_rels
diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py
index cd016a117..e452219de 100644
--- a/tests/document/processing/test_text_pair.py
+++ b/tests/document/processing/test_text_pair.py
@@ -204,7 +204,7 @@ def positive_and_negative_documents(positive_documents):
 
 def test_construct_negative_documents(positive_and_negative_documents):
-    assert len(positive_and_negative_documents) == 16
+    assert len(positive_and_negative_documents) == 8
     TEXTS = [
         "Entity A works at B.",
         "And she founded C.",
@@ -229,7 +229,6 @@ def test_construct_negative_documents(positive_and_negative_documents):
     ]
 
     assert all_rels_and_scores == [
-        (("And she founded C.", "And she founded C."), []),
         (
             ("And she founded C.", "Bob loves his cat."),
             [(0.0, ("coref", (("PERSON", "she"), ("PERSON", "Bob"))))],
@@ -241,12 +240,10 @@ def test_construct_negative_documents(positive_and_negative_documents):
                 (0.0, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))),
             ],
         ),
-        (("And she founded C.", "She sleeps a lot."), []),
         (
             ("Bob loves his cat.", "And she founded C."),
             [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she"))))],
         ),
-        (("Bob loves his cat.", "Bob loves his cat."), []),
         (
             ("Bob loves his cat.", "Entity A works at B."),
             [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A"))))],
@@ -266,15 +263,10 @@ def test_construct_negative_documents(positive_and_negative_documents):
             ("Entity A works at B.", "Bob loves his cat."),
             [(0.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "Bob"))))],
         ),
-        (("Entity A works at B.", "Entity A works at B."), []),
-        (("Entity A works at B.", "She sleeps a lot."), []),
-        (("She sleeps a lot.", "And she founded C."), []),
         (
             ("She sleeps a lot.", "Bob loves his cat."),
             [(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))],
         ),
-        (("She sleeps a lot.", "Entity A works at B."), []),
-        (("She sleeps a lot.", "She sleeps a lot."), []),
     ]
@@ -289,14 +281,8 @@ def test_construct_text_document_from_text_pair_coref_document(positive_and_nega
         )
         for doc in positive_and_negative_documents
     ]
-    assert len(docs) == 16
+    assert len(docs) == 8
     doc = docs[0]
-    assert doc.text == "And she founded C."
-    assert doc.labeled_spans.resolve() == [("PERSON", "she"), ("COMPANY", "C")]
-    assert doc.binary_relations.resolve() == []
-    assert [rel.score for rel in doc.binary_relations] == []
-
-    doc = docs[1]
     assert doc.text == "And she founded C.Bob loves his cat."
     assert doc.labeled_spans.resolve() == [
         ("PERSON", "she"),
         ("COMPANY", "C"),
         ("PERSON", "Bob"),
         ("ANIMAL", "his cat"),
     ]
@@ -309,7 +295,7 @@ def test_construct_text_document_from_text_pair_coref_document(positive_and_nega
     assert [rel.score for rel in doc.binary_relations] == [1.0]
 
-    doc = docs[7]
+    doc = docs[4]
     assert doc.text == "Bob loves his cat.She sleeps a lot."
     assert doc.labeled_spans.resolve() == [
         ("PERSON", "Bob"),
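After this change, `add_negative_coref_relations` no longer emits text-pair documents whose relation layer stayed empty (identical-text pairs and pairs without any candidate). A small sanity-check sketch, assuming documents shaped like the `positive_documents` fixture above:

```python
from pie_modules.document.processing.text_pair import add_negative_coref_relations

docs = list(add_negative_coref_relations(positive_documents))
# every returned pair now carries at least one positive or negative candidate
assert all(len(doc.binary_coref_relations) > 0 for doc in docs)
assert len(docs) == 8  # down from 16, see the adjusted test above
```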
From 11de48b182aed93d34dd8f11d504896c0a1e31b0 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Sun, 15 Sep 2024 22:41:54 +0200
Subject: [PATCH 40/49] implement downsampling for add_negative_coref_relations()

---
 .../document/processing/text_pair.py | 22 ++++-
 tests/document/processing/test_text_pair.py | 88 +++++++++++++++++++
 2 files changed, 108 insertions(+), 2 deletions(-)

diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py
index e052342cc..1117b349a 100644
--- a/src/pie_modules/document/processing/text_pair.py
+++ b/src/pie_modules/document/processing/text_pair.py
@@ -1,4 +1,5 @@
 import copy
+import random
 from collections import defaultdict
 from collections.abc import Iterator
 from itertools import chain
@@ -177,7 +178,9 @@ def construct_text_document_from_text_pair_coref_document(
 
 
 def add_negative_coref_relations(
-    documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations], **kwargs
+    documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations],
+    downsampling_factor: Optional[float] = None,
+    **kwargs,
 ) -> Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]:
     positive_tuples = defaultdict(set)
     text2spans = defaultdict(set)
@@ -225,7 +228,22 @@ def add_negative_coref_relations(
     for rel in positive_rels:
         new_rels2new_docs[rel].binary_coref_relations.append(rel)
 
-    # TODO: implement downsampling
+    # Downsampling of negatives. This requires positive instances!
+    if downsampling_factor is not None:
+        if len(positive_rels) == 0:
+            raise ValueError(
+                f"downsampling [factor={downsampling_factor}] is enabled, "
+                f"but no positive relations are available to calculate max_num_negative"
+            )
+
+        max_num_negative = int(len(positive_rels) * downsampling_factor)
+        if max_num_negative == 0:
+            raise ValueError(
+                f"downsampling with factor={downsampling_factor} and number of "
+                f"positive relations={len(positive_rels)} does not produce any negatives"
+            )
+        random.shuffle(negative_rels)
+        negative_rels = negative_rels[:max_num_negative]
     for rel in negative_rels:
         new_rels2new_docs[rel].binary_coref_relations.append(rel)
 
diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py
index e452219de..b64382180 100644
--- a/tests/document/processing/test_text_pair.py
+++ b/tests/document/processing/test_text_pair.py
@@ -1,3 +1,4 @@
+import random
 from itertools import chain
 from typing import List
 
@@ -223,6 +224,16 @@ def test_construct_negative_documents(positive_and_negative_documents):
         doc.binary_coref_relations.resolve() for doc in positive_and_negative_documents
     ]
 
+    # check number of all relations
+    all_rels_flat = [
+        rel for doc in positive_and_negative_documents for rel in doc.binary_coref_relations
+    ]
+    assert len(all_rels_flat) == 10
+    # positives
+    assert len([rel.score for rel in all_rels_flat if rel.score > 0.0]) == 4
+    # negatives
+    assert len([rel.score for rel in all_rels_flat if rel.score == 0.0]) == 6
+
     all_rels_and_scores = [
         (texts, list(zip(scores, rels_resolved)))
         for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved)
@@ -270,6 +281,83 @@ def test_construct_negative_documents(positive_and_negative_documents):
     ]
 
 
+def test_construct_negative_documents_with_downsampling(positive_documents):
+    # set fixed seed because the negatives will get shuffled
+    random.seed(42)
+    docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=1.0))
+    all_texts = [(doc.text, doc.text_pair) for doc in docs]
+    all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs]
+    all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs]
+
+    all_rels_and_scores = [
+        (texts, list(zip(scores, rels_resolved)))
+        for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved)
+    ]
+
+    # check number relations
+    all_rels_flat = [rel for doc in docs for rel in doc.binary_coref_relations]
+    # positives
+    assert len([rel.score for rel in all_rels_flat if rel.score > 0.0]) == 4
+    # negatives (same number positives because downsampling_factor=1.0)
+    assert len([rel.score for rel in all_rels_flat if rel.score == 0.0]) == 4
+
+    assert all_rels_and_scores == [
+        (
+            ("And she founded C.", "Entity A works at B."),
+            [
+                (1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A")))),
+                (0.0, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))),
+            ],
+        ),
+        (
+            ("Bob loves his cat.", "And she founded C."),
+            [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she"))))],
+        ),
+        (
+            ("Bob loves his cat.", "Entity A works at B."),
+            [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A"))))],
+        ),
+        (
+            ("Bob loves his cat.", "She sleeps a lot."),
+            [(1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))))],
+        ),
+        (
+            ("Entity A works at B.", "And she founded C."),
+            [
+                (1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))),
+                (0.0, ("coref", (("COMPANY", "B"), ("COMPANY", "C")))),
+            ],
+        ),
+        (
+            ("She sleeps a lot.", "Bob loves his cat."),
+            [(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))],
+        ),
+    ]
+
+    # no positives
+    doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
+        id="0", text="Bob loves his cat.", text_pair="She sleeps a lot."
+    )
+    doc2.labeled_spans.append(LabeledSpan(start=0, end=3, label="PERSON"))
+    doc2.labeled_spans.append(LabeledSpan(start=10, end=17, label="ANIMAL"))
+    doc2.labeled_spans_pair.append(LabeledSpan(start=0, end=3, label="ANIMAL"))
+    with pytest.raises(ValueError) as e:
+        list(add_negative_coref_relations([doc2], downsampling_factor=1.0))
+    assert (
+        str(e.value)
+        == "downsampling [factor=1.0] is enabled, but no positive relations are available to calculate "
+        "max_num_negative"
+    )
+
+    # sampling target is too low
+    with pytest.raises(ValueError) as e:
+        list(add_negative_coref_relations(positive_documents, downsampling_factor=0.0))
+    assert (
+        str(e.value)
+        == "downsampling with factor=0.0 and number of positive relations=4 does not produce any negatives"
+    )
+
+
 def test_construct_text_document_from_text_pair_coref_document(positive_and_negative_documents):
     glue_text = ""
     docs = [

From 538e9ee1b284991594466fc562d9653340a0dcd1 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Sun, 15 Sep 2024 22:42:22 +0200
Subject: [PATCH 41/49] remove unused kwargs from add_negative_coref_relations()

---
 src/pie_modules/document/processing/text_pair.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py
index 1117b349a..9808d042e 100644
--- a/src/pie_modules/document/processing/text_pair.py
+++ b/src/pie_modules/document/processing/text_pair.py
@@ -180,7 +180,6 @@ def construct_text_document_from_text_pair_coref_document(
 def add_negative_coref_relations(
     documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations],
     downsampling_factor: Optional[float] = None,
-    **kwargs,
 ) -> Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]:
     positive_tuples = defaultdict(set)
     text2spans = defaultdict(set)
From 833eb8ca791d69a936a1ed3449fc83bd3985b517 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 16 Sep 2024 14:06:02 +0200
Subject: [PATCH 42/49] allow that downsampling negatives does not produce negatives at all

---
 .../document/processing/text_pair.py | 8 ++++++--
 tests/document/processing/test_text_pair.py | 54 ++++++++++++++-----
 2 files changed, 48 insertions(+), 14 deletions(-)

diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py
index 9808d042e..756619eda 100644
--- a/src/pie_modules/document/processing/text_pair.py
+++ b/src/pie_modules/document/processing/text_pair.py
@@ -1,4 +1,5 @@
 import copy
+import logging
 import random
 from collections import defaultdict
 from collections.abc import Iterator
 from itertools import chain
@@ -18,6 +19,8 @@
 )
 from pie_modules.utils.span import are_nested
 
+logger = logging.getLogger(__name__)
+
 S = TypeVar("S", bound=Span)
 S2 = TypeVar("S2", bound=Span)
 
@@ -237,11 +240,12 @@ def add_negative_coref_relations(
 
         max_num_negative = int(len(positive_rels) * downsampling_factor)
         if max_num_negative == 0:
-            raise ValueError(
+            logger.warning(
                 f"downsampling with factor={downsampling_factor} and number of "
                 f"positive relations={len(positive_rels)} does not produce any negatives"
             )
-        random.shuffle(negative_rels)
+        else:
+            random.shuffle(negative_rels)
         negative_rels = negative_rels[:max_num_negative]
     for rel in negative_rels:
         new_rels2new_docs[rel].binary_coref_relations.append(rel)
diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py
index b64382180..dd205ce7d 100644
--- a/tests/document/processing/test_text_pair.py
+++ b/tests/document/processing/test_text_pair.py
@@ -281,10 +281,7 @@ def test_construct_negative_documents(positive_and_negative_documents):
     ]
 
 
-def test_construct_negative_documents_with_downsampling(positive_documents):
-    # set fixed seed because the negatives will get shuffled
-    random.seed(42)
-    docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=1.0))
+def _get_all_all_rels_and_scores(docs):
     all_texts = [(doc.text, doc.text_pair) for doc in docs]
     all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs]
     all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs]
@@ -293,6 +290,14 @@ def test_construct_negative_documents_with_downsampling(positive_documents):
         (texts, list(zip(scores, rels_resolved)))
         for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved)
     ]
+    return all_rels_and_scores
+
+
+def test_construct_negative_documents_with_downsampling(positive_documents, caplog):
+    # set fixed seed because the negatives will get shuffled
+    random.seed(42)
+    docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=1.0))
+    all_rels_and_scores = _get_all_all_rels_and_scores(docs)
 
     # check number relations
     all_rels_flat = [rel for doc in docs for rel in doc.binary_coref_relations]
@@ -334,6 +339,39 @@ def test_construct_negative_documents_with_downsampling(positive_documents):
         ),
     ]
 
+    # sampling target is too low
+    caplog.clear()
+    docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=0.0))
+    assert caplog.messages == [
+        "downsampling with factor=0.0 and number of positive relations=4 does not produce any negatives"
+    ]
+    # check number relations
+    all_rels_flat = [rel for doc in docs for rel in doc.binary_coref_relations]
+    # positives: 2 x number of positives (we add instances with swapped texts)
+    assert len([rel.score for rel in all_rels_flat if rel.score > 0.0]) == 4
+    # negatives
+    assert len([rel.score for rel in all_rels_flat if rel.score == 0.0]) == 0
+    # check actual content
+    all_rels_and_scores = _get_all_all_rels_and_scores(docs)
+    assert all_rels_and_scores == [
+        (
+            ("And she founded C.", "Entity A works at B."),
+            [(1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A"))))],
+        ),
+        (
+            ("Bob loves his cat.", "She sleeps a lot."),
+            [(1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))))],
+        ),
+        (
+            ("Entity A works at B.", "And she founded C."),
+            [(1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))))],
+        ),
+        (
+            ("She sleeps a lot.", "Bob loves his cat."),
+            [(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))],
+        ),
+    ]
+
     # no positives
     doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
         id="0", text="Bob loves his cat.", text_pair="She sleeps a lot."
@@ -349,14 +387,6 @@ def test_construct_negative_documents_with_downsampling(positive_documents):
         "max_num_negative"
     )
 
-    # sampling target is too low
-    with pytest.raises(ValueError) as e:
-        list(add_negative_coref_relations(positive_documents, downsampling_factor=0.0))
-    assert (
-        str(e.value)
-        == "downsampling with factor=0.0 and number of positive relations=4 does not produce any negatives"
-    )
-
 
 def test_construct_text_document_from_text_pair_coref_document(positive_and_negative_documents):
     glue_text = ""
     docs = [

From e72fcc1bd1d496c3b203c795f820c0f1cc9d63aa Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 16 Sep 2024 15:22:11 +0200
Subject: [PATCH 43/49] fix existing and add more metrics; rename "labels" / "probabilities" to "scores" in model input / output

---
 .../sequence_classification_with_pooler.py | 8 +-
 .../taskmodules/cross_text_binary_coref.py | 46 ++++++--
 ...uence_pair_similarity_model_with_pooler.py | 8 +-
 .../test_cross_text_binary_coref.py | 104 ++++++++++++++----
 4 files changed, 127 insertions(+), 39 deletions(-)

diff --git a/src/pie_modules/models/sequence_classification_with_pooler.py b/src/pie_modules/models/sequence_classification_with_pooler.py
index be9e39f6c..5e198b4a5 100644
--- a/src/pie_modules/models/sequence_classification_with_pooler.py
+++ b/src/pie_modules/models/sequence_classification_with_pooler.py
@@ -330,7 +330,7 @@ def forward(
         result = {"logits": logits}
         if targets is not None:
-            labels = targets["labels"]
+            labels = targets["scores"]
             loss = self.loss_fct(logits, labels)
             result["loss"] = loss
         if return_hidden_states:
@@ -340,6 +340,6 @@ def forward(
 
     def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
         # probabilities = torch.sigmoid(outputs.logits)
-        probabilities = outputs.logits
-        labels = (probabilities > self.multi_label_threshold).to(torch.long)
-        return {"labels": labels, "probabilities": probabilities}
+        scores = outputs.logits
+        labels = (scores > self.multi_label_threshold).to(torch.long)
+        return {"labels": labels, "scores": scores}
diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py
index 3f039cec3..e3de1bfc4 100644
--- a/src/pie_modules/taskmodules/cross_text_binary_coref.py
+++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py
@@ -19,7 +19,7 @@
 from pytorch_ie.annotations import Span
 from pytorch_ie.core import TaskEncoding, TaskModule
 from pytorch_ie.utils.window import get_window_around_slice
-from torchmetrics import Metric, MetricCollection
+from torchmetrics import ClasswiseWrapper, F1Score, Metric, MetricCollection
 from torchmetrics.classification import BinaryAUROC
 from transformers import AutoTokenizer, BatchEncoding
 from typing_extensions import TypeAlias
@@ -77,6 +77,10 @@ def _get_labels(model_output: ModelTargetType) -> torch.Tensor:
     return model_output["labels"]
 
 
+def _get_scores(model_output: ModelTargetType) -> torch.Tensor:
+    return model_output["scores"]
+
+
 S = TypeVar("S", bound=Span)
 
 
@@ -240,22 +244,46 @@ def collate(
         if not task_encodings[0].has_targets:
             return inputs, None
         targets = {
-            "labels": torch.tensor([task_encoding.targets for task_encoding in task_encodings])
+            "scores": torch.tensor([task_encoding.targets for task_encoding in task_encodings])
         }
         return inputs, targets
 
-    def configure_model_metric(self, stage: str) -> Metric:
-        return WrappedMetricWithPrepareFunction(
-            metric=MetricCollection({"auroc": BinaryAUROC(thresholds=None)}),
-            prepare_function=_get_labels,
+    def configure_model_metric(self, stage: str) -> MetricCollection:
+        # we use the length of label_to_id because that contains the none_label (in contrast to labels)
+        labels = ["no_relation", "coref"]
+        common_metric_kwargs = {
+            "num_classes": len(labels),
+            "task": "multiclass",
+        }
+
+        return MetricCollection(
+            metrics={
+                "continuous": WrappedMetricWithPrepareFunction(
+                    metric=MetricCollection({"auroc": BinaryAUROC(thresholds=None)}),
+                    prepare_function=_get_scores,
+                ),
+                "discrete": WrappedMetricWithPrepareFunction(
+                    metric=MetricCollection(
+                        {
+                            "micro/f1": F1Score(average="micro", **common_metric_kwargs),
+                            "macro/f1": F1Score(average="macro", **common_metric_kwargs),
+                            "f1_per_label": ClasswiseWrapper(
+                                F1Score(average=None, **common_metric_kwargs),
+                                labels=labels,
+                                postfix="/f1",
+                            ),
+                        }
+                    ),
+                    prepare_function=_get_labels,
+                ),
+            }
         )
 
     def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]:
         label_ids = model_output["labels"].detach().cpu().tolist()
-        probabilities = model_output["probabilities"].detach().cpu().tolist()
+        scores = model_output["scores"].detach().cpu().tolist()
         result: List[TaskOutputType] = [
-            {"is_valid": label_id != 0, "score": prob}
-            for label_id, prob in zip(label_ids, probabilities)
+            {"is_valid": label_id != 0, "score": prob} for label_id, prob in zip(label_ids, scores)
         ]
         return result
diff --git a/tests/models/test_sequence_pair_similarity_model_with_pooler.py b/tests/models/test_sequence_pair_similarity_model_with_pooler.py
index 6bd48b6bb..bef436d68 100644
--- a/tests/models/test_sequence_pair_similarity_model_with_pooler.py
+++ b/tests/models/test_sequence_pair_similarity_model_with_pooler.py
@@ -78,7 +78,7 @@ def inputs() -> Dict[str, LongTensor]:
 
 @pytest.fixture
 def targets() -> Dict[str, LongTensor]:
-    return {"labels": tensor([0.0, 0.0, 0.0, 0.0])}
+    return {"scores": tensor([0.0, 0.0, 0.0, 0.0])}
 
 
 @pytest.fixture
@@ -169,15 +169,15 @@ def test_forward_logits(model_output, inputs):
 def test_decode(model, model_output, inputs):
     decoded = model.decode(inputs=inputs, outputs=model_output)
     assert isinstance(decoded, dict)
-    assert set(decoded) == {"labels", "probabilities"}
+    assert set(decoded) == {"labels", "scores"}
     labels = decoded["labels"]
     torch.testing.assert_close(
         labels,
         torch.tensor([1, 1, 1, 1]),
     )
-    probabilities = decoded["probabilities"]
+    scores = decoded["scores"]
     torch.testing.assert_close(
-        probabilities,
+        scores,
         torch.tensor(
             [0.5338148474693298, 0.5866107940673828, 0.5076886415481567, 0.5946245789527893]
         ),
diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py
index 5555a0a3c..bb3c5b0f0 100644
--- a/tests/taskmodules/test_cross_text_binary_coref.py
+++ b/tests/taskmodules/test_cross_text_binary_coref.py
@@ -5,6 +5,7 @@
 import pytest
 import torch.testing
 from pytorch_ie.annotations import LabeledSpan
+from torch import tensor
 from torchmetrics import Metric, MetricCollection
 
 from pie_modules.document.processing.text_pair import add_negative_coref_relations
@@ -273,14 +274,14 @@ def test_collate(batch, taskmodule):
     torch.testing.assert_close(inputs["pooler_pair_start_indices"], torch.tensor([[1], [7]]))
     torch.testing.assert_close(inputs["pooler_pair_end_indices"], torch.tensor([[5], [8]]))
 
-    torch.testing.assert_close(targets, {"labels": torch.tensor([1.0, 0.0])})
+    torch.testing.assert_close(targets, {"scores": torch.tensor([1.0, 0.0])})
 
 
 @pytest.fixture(scope="module")
 def unbatched_output(taskmodule):
     model_output = {
         "labels": torch.tensor([0, 1]),
-        "probabilities": torch.tensor([0.5338148474693298, 0.9866107940673828]),
+        "scores": torch.tensor([0.5338148474693298, 0.9866107940673828]),
     }
     return taskmodule.unbatch_output(model_output=model_output)
 
@@ -323,46 +324,105 @@ def test_configure_metric(taskmodule, batch):
     assert isinstance(metric, (Metric, MetricCollection))
 
     state = get_metric_state(metric)
-    assert state == {"auroc/preds": [], "auroc/target": []}
+    torch.testing.assert_close(
+        state,
+        {
+            "continuous/auroc/preds": [],
+            "continuous/auroc/target": [],
+            "discrete/f1_per_label/tp": tensor([0, 0]),
+            "discrete/f1_per_label/fp": tensor([0, 0]),
+            "discrete/f1_per_label/tn": tensor([0, 0]),
+            "discrete/f1_per_label/fn": tensor([0, 0]),
+            "discrete/macro/f1/tp": tensor([0, 0]),
+            "discrete/macro/f1/fp": tensor([0, 0]),
+            "discrete/macro/f1/tn": tensor([0, 0]),
+            "discrete/macro/f1/fn": tensor([0, 0]),
+            "discrete/micro/f1/tp": tensor([0]),
+            "discrete/micro/f1/fp": tensor([0]),
+            "discrete/micro/f1/tn": tensor([0]),
+            "discrete/micro/f1/fn": tensor([0]),
+        },
+    )
 
     # targets = batch[1]
-    targets = {"labels": torch.tensor([0.0, 1.0, 0.0, 0.0])}
+    targets = {
+        "labels": torch.tensor([0, 1, 0, 0]),
+        "scores": torch.tensor([0.0, 1.0, 0.0, 0.0]),
+    }
     metric.update(targets, targets)
 
     state = get_metric_state(metric)
     torch.testing.assert_close(
         state,
         {
-            "auroc/preds": [torch.tensor([0.0, 1.0, 0.0, 0.0])],
-            "auroc/target": [torch.tensor([0.0, 1.0, 0.0, 0.0])],
+            "continuous/auroc/preds": [tensor([0.0, 1.0, 0.0, 0.0])],
+            "continuous/auroc/target": [tensor([0.0, 1.0, 0.0, 0.0])],
+            "discrete/f1_per_label/tp": tensor([3, 1]),
+            "discrete/f1_per_label/fp": tensor([0, 0]),
+            "discrete/f1_per_label/tn": tensor([1, 3]),
+            "discrete/f1_per_label/fn": tensor([0, 0]),
+            "discrete/macro/f1/tp": tensor([3, 1]),
+            "discrete/macro/f1/fp": tensor([0, 0]),
+            "discrete/macro/f1/tn": tensor([1, 3]),
+            "discrete/macro/f1/fn": tensor([0, 0]),
+            "discrete/micro/f1/tp": tensor([4]),
+            "discrete/micro/f1/fp": tensor([0]),
+            "discrete/micro/f1/tn": tensor([4]),
+            "discrete/micro/f1/fn": tensor([0]),
         },
     )
 
-    assert metric.compute() == {"auroc": torch.tensor(1.0)}
+    torch.testing.assert_close(
+        metric.compute(),
+        {
+            "auroc": tensor(1.0),
+            "no_relation/f1": tensor(1.0),
+            "coref/f1": tensor(1.0),
+            "macro/f1": tensor(1.0),
+            "micro/f1": tensor(1.0),
+        },
+    )
 
     # torch.rand_like(targets)
-    random_targets = {"labels": torch.tensor([0.2703, 0.6812, 0.2582, 0.8030])}
+    random_targets = {
+        "labels": torch.tensor([0, 0, 0, 1]),
+        "scores": torch.tensor([0.2703, 0.6812, 0.2582, 0.8030]),
+    }
     metric.update(random_targets, targets)
     state = get_metric_state(metric)
     torch.testing.assert_close(
         state,
         {
-            "auroc/preds": [
-                torch.tensor([0.0, 1.0, 0.0, 0.0]),
-                torch.tensor(
-                    [
-                        0.2703000009059906,
-                        0.6812000274658203,
-                        0.2581999897956848,
-                        0.8029999732971191,
-                    ]
-                ),
+            "continuous/auroc/preds": [
+                tensor([0.0, 1.0, 0.0, 0.0]),
+                tensor([0.2703, 0.6812, 0.2582, 0.8030]),
             ],
-            "auroc/target": [
-                torch.tensor([0.0, 1.0, 0.0, 0.0]),
-                torch.tensor([0.0, 1.0, 0.0, 0.0]),
+            "continuous/auroc/target": [
+                tensor([0.0, 1.0, 0.0, 0.0]),
+                tensor([0.0, 1.0, 0.0, 0.0]),
             ],
+            "discrete/f1_per_label/tp": tensor([5, 1]),
+            "discrete/f1_per_label/fp": tensor([1, 1]),
+            "discrete/f1_per_label/tn": tensor([1, 5]),
+            "discrete/f1_per_label/fn": tensor([1, 1]),
+            "discrete/macro/f1/tp": tensor([5, 1]),
+            "discrete/macro/f1/fp": tensor([1, 1]),
+            "discrete/macro/f1/tn": tensor([1, 5]),
+            "discrete/macro/f1/fn": tensor([1, 1]),
+            "discrete/micro/f1/tp": tensor([6]),
+            "discrete/micro/f1/fp": tensor([2]),
+            "discrete/micro/f1/tn": tensor([6]),
+            "discrete/micro/f1/fn": tensor([2]),
         },
     )
 
-    assert metric.compute() == {"auroc": torch.tensor(0.9166666269302368)}
+    torch.testing.assert_close(
+        metric.compute(),
+        {
+            "auroc": tensor(0.916667),
+            "no_relation/f1": tensor(0.833333),
+            "coref/f1": tensor(0.500000),
+            "macro/f1": tensor(0.666667),
+            "micro/f1": tensor(0.750000),
+        },
+    )
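The metric collection configured above now has a "continuous" branch (fed with raw scores, for AUROC) and a "discrete" branch (fed with thresholded labels, for the F1 variants). A standalone sketch of the continuous part using plain torchmetrics, with synthetic values not taken from the tests:

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAUROC

metric = MetricCollection({"auroc": BinaryAUROC(thresholds=None)})
scores = torch.tensor([0.1, 0.9, 0.2, 0.8])  # model "scores"
targets = torch.tensor([0, 1, 0, 1])         # gold labels
metric.update(scores, targets)
print(metric.compute())  # {'auroc': tensor(1.)} for perfectly ranked scores
```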
From c2944291b72cb7c0db6cd9bb3d725035d609a032 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 16 Sep 2024 15:39:31 +0200
Subject: [PATCH 44/49] move label_threshold from model to taskmodule; rename "is_valid" to "is_similar" in TaskOutputType

---
 .../sequence_classification_with_pooler.py | 5 +----
 .../taskmodules/cross_text_binary_coref.py | 17 ++++++++++-------
 ...equence_pair_similarity_model_with_pooler.py | 8 +-------
 .../taskmodules/test_cross_text_binary_coref.py | 11 ++++-------
 4 files changed, 16 insertions(+), 25 deletions(-)

diff --git a/src/pie_modules/models/sequence_classification_with_pooler.py b/src/pie_modules/models/sequence_classification_with_pooler.py
index 5e198b4a5..cc5c2ea39 100644
--- a/src/pie_modules/models/sequence_classification_with_pooler.py
+++ b/src/pie_modules/models/sequence_classification_with_pooler.py
@@ -286,7 +286,6 @@ class SequencePairSimilarityModelWithPooler(
 
     def __init__(
         self,
-        label_threshold: float = 0.9,
         pooler: Optional[Union[Dict[str, Any], str]] = None,
         **kwargs,
     ):
@@ -294,7 +293,6 @@ def __init__(
             # use (max) mention pooling per default
             pooler = {"type": "mention_pooling", "num_indices": 1}
         super().__init__(pooler=pooler, **kwargs)
-        self.multi_label_threshold = label_threshold
 
     def setup_classifier(
         self, pooler_output_dim: int
@@ -341,5 +339,4 @@ def forward(
     def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
         # probabilities = torch.sigmoid(outputs.logits)
         scores = outputs.logits
-        labels = (scores > self.multi_label_threshold).to(torch.long)
-        return {"labels": labels, "scores": scores}
+        return {"scores": scores}
diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py
index e3de1bfc4..b8bc5f1a6 100644
--- a/src/pie_modules/taskmodules/cross_text_binary_coref.py
+++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py
@@ -1,5 +1,6 @@
 import copy
 import logging
+from functools import partial
 from typing import (
     Any,
     Dict,
@@ -50,7 +51,7 @@
 
 class TaskOutputType(TypedDict, total=False):
     score: float
-    is_valid: bool
+    is_similar: bool
 
 
 ModelInputType: TypeAlias = Dict[str, torch.Tensor]
@@ -73,8 +74,8 @@ def __init__(self, span):
         self.span = span
 
 
-def _get_labels(model_output: ModelTargetType) -> torch.Tensor:
-    return model_output["labels"]
+def _get_labels(model_output: ModelTargetType, label_threshold: float) -> torch.Tensor:
+    return (model_output["scores"] > label_threshold).to(torch.int)
 
 
 def _get_scores(model_output: ModelTargetType) -> torch.Tensor:
@@ -99,6 +100,7 @@ class CrossTextBinaryCorefTaskModule(RelationStatisticsMixin, TaskModuleType):
     def __init__(
         self,
         tokenizer_name_or_path: str,
+        label_threshold: float = 0.9,
         max_window: Optional[int] = None,
         **kwargs,
     ) -> None:
@@ -106,6 +108,7 @@ def __init__(
         self.save_hyperparameters()
 
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+        self.label_threshold = label_threshold
        self.max_window = max_window if max_window is not None else self.tokenizer.model_max_length
         self.available_window = self.max_window - self.tokenizer.num_special_tokens_to_add()
         self.num_special_tokens_before = len(self._get_special_tokens_before_input())
@@ -274,16 +277,16 @@ def configure_model_metric(self, stage: str) -> MetricCollection:
                         ),
                     }
                 ),
-                    prepare_function=_get_labels,
+                    prepare_function=partial(_get_labels, label_threshold=self.label_threshold),
                 ),
             }
         )
 
     def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]:
-        label_ids = model_output["labels"].detach().cpu().tolist()
+        is_similar = (model_output["scores"] > self.label_threshold).detach().cpu().tolist()
         scores = model_output["scores"].detach().cpu().tolist()
         result: List[TaskOutputType] = [
-            {"is_valid": label_id != 0, "score": prob} for label_id, prob in zip(label_ids, scores)
+            {"is_similar": is_sim, "score": prob} for is_sim, prob in zip(is_similar, scores)
         ]
         return result
 
@@ -292,7 +295,7 @@ def create_annotations_from_output(
         task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType],
         task_output: TaskOutputType,
     ) -> Iterator[Tuple[str, Annotation]]:
-        if task_output["is_valid"]:
+        if task_output["is_similar"]:
             score = task_output["score"]
             new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score)
             yield "binary_coref_relations", new_coref_rel
diff --git a/tests/models/test_sequence_pair_similarity_model_with_pooler.py b/tests/models/test_sequence_pair_similarity_model_with_pooler.py
index bef436d68..3e6871a11 100644
--- a/tests/models/test_sequence_pair_similarity_model_with_pooler.py
+++ b/tests/models/test_sequence_pair_similarity_model_with_pooler.py
@@ -86,7 +86,6 @@ def model() -> SequencePairSimilarityModelWithPooler:
     torch.manual_seed(42)
     result = SequencePairSimilarityModelWithPooler(
         model_name_or_path="prajjwal1/bert-tiny",
-        label_threshold=0.5,
     )
     return result
 
@@ -169,12 +168,7 @@ def test_forward_logits(model_output, inputs):
 def test_decode(model, model_output, inputs):
     decoded = model.decode(inputs=inputs, outputs=model_output)
     assert isinstance(decoded, dict)
-    assert set(decoded) == {"labels", "scores"}
-    labels = decoded["labels"]
-    torch.testing.assert_close(
-        labels,
-        torch.tensor([1, 1, 1, 1]),
-    )
+    assert set(decoded) == {"scores"}
     scores = decoded["scores"]
     torch.testing.assert_close(
         scores,
diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py
index bb3c5b0f0..010f9152f 100644
--- a/tests/taskmodules/test_cross_text_binary_coref.py
+++ b/tests/taskmodules/test_cross_text_binary_coref.py
@@ -280,7 +280,6 @@ def test_collate(batch, taskmodule):
 @pytest.fixture(scope="module")
 def unbatched_output(taskmodule):
     model_output = {
-        "labels": torch.tensor([0, 1]),
         "scores": torch.tensor([0.5338148474693298, 0.9866107940673828]),
     }
     return taskmodule.unbatch_output(model_output=model_output)
@@ -289,8 +288,8 @@ def unbatched_output(taskmodule):
 def test_unbatch_output(unbatched_output, taskmodule):
     assert len(unbatched_output) == 2
     assert unbatched_output == [
-        {"is_valid": False, "score": 0.5338148474693298},
-        {"is_valid": True, "score": 0.9866107702255249},
+        {"is_similar": False, "score": 0.5338148474693298},
+        {"is_similar": True, "score": 0.9866107702255249},
     ]
 
 
@@ -346,7 +345,6 @@ def test_configure_metric(taskmodule, batch):
 
     # targets = batch[1]
     targets = {
-        "labels": torch.tensor([0, 1, 0, 0]),
         "scores": torch.tensor([0.0, 1.0, 0.0, 0.0]),
     }
     metric.update(targets, targets)
@@ -385,8 +383,7 @@ def test_configure_metric(taskmodule, batch):
 
     # torch.rand_like(targets)
     random_targets = {
-        "labels": torch.tensor([0, 0, 0, 1]),
-        "scores": torch.tensor([0.2703, 0.6812, 0.2582, 0.8030]),
+        "scores": torch.tensor([0.2703, 0.6812, 0.2582, 0.9030]),
     }
     metric.update(random_targets, targets)
     state = get_metric_state(metric)
@@ -395,7 +392,7 @@ def test_configure_metric(taskmodule, batch):
         {
             "continuous/auroc/preds": [
                 tensor([0.0, 1.0, 0.0, 0.0]),
-                tensor([0.2703, 0.6812, 0.2582, 0.8030]),
+                tensor([0.2703, 0.6812, 0.2582, 0.9030]),
             ],
             "continuous/auroc/target": [
                 tensor([0.0, 1.0, 0.0, 0.0]),

From 7c68696ac6d6f7efb1502fe5cdfcd5c2648f00c1 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 16 Sep 2024 15:45:42 +0200
Subject: [PATCH 45/49] rename parameter "label_threshold" to "similarity_threshold"

---
 src/pie_modules/taskmodules/cross_text_binary_coref.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py
index b8bc5f1a6..a156ebf5d 100644
--- a/src/pie_modules/taskmodules/cross_text_binary_coref.py
+++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py
@@ -100,7 +100,7 @@ class CrossTextBinaryCorefTaskModule(RelationStatisticsMixin, TaskModuleType):
     def __init__(
         self,
         tokenizer_name_or_path: str,
-        label_threshold: float = 0.9,
+        similarity_threshold: float = 0.9,
         max_window: Optional[int] = None,
         **kwargs,
     ) -> None:
@@ -108,7 +108,7 @@ def __init__(
         self.save_hyperparameters()
 
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
-        self.label_threshold = label_threshold
+        self.similarity_threshold = similarity_threshold
         self.max_window = max_window if max_window is not None else self.tokenizer.model_max_length
         self.available_window = self.max_window - self.tokenizer.num_special_tokens_to_add()
         self.num_special_tokens_before = len(self._get_special_tokens_before_input())
@@ -277,13 +277,15 @@ def configure_model_metric(self, stage: str) -> MetricCollection:
                         ),
                     }
                 ),
-                    prepare_function=partial(_get_labels, label_threshold=self.label_threshold),
+                    prepare_function=partial(
+                        _get_labels, label_threshold=self.similarity_threshold
+                    ),
                 ),
             }
         )
 
     def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]:
-        is_similar = (model_output["scores"] > self.label_threshold).detach().cpu().tolist()
+        is_similar = (model_output["scores"] > self.similarity_threshold).detach().cpu().tolist()
         scores = model_output["scores"].detach().cpu().tolist()
         result: List[TaskOutputType] = [
             {"is_similar": is_sim, "score": prob} for is_sim, prob in zip(is_similar, scores)
         ]
         return result

From 049e69e3d33d0530181706ac161014e8094caa29 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 16 Sep 2024 16:34:18 +0200
Subject: [PATCH 46/49] add metric: avg-P (BinaryAveragePrecision)

---
 .../taskmodules/cross_text_binary_coref.py | 16 ++++++++++++++--
 .../taskmodules/test_cross_text_binary_coref.py | 14 ++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py
index a156ebf5d..2ceb890cb 100644
--- a/src/pie_modules/taskmodules/cross_text_binary_coref.py
+++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py
@@ -21,7 +21,12 @@
 from pytorch_ie.core import TaskEncoding, TaskModule
 from pytorch_ie.utils.window import get_window_around_slice
 from torchmetrics import ClasswiseWrapper, F1Score, Metric, MetricCollection
-from torchmetrics.classification import BinaryAUROC
+from torchmetrics.classification import (
+    BinaryAUROC,
+    BinaryAveragePrecision,
+    BinaryPrecisionRecallCurve,
+    BinaryROC,
+)
 from transformers import AutoTokenizer, BatchEncoding
 from typing_extensions import TypeAlias
 
@@ -262,7 +267,14 @@ def configure_model_metric(self, stage: str) -> MetricCollection:
         return MetricCollection(
             metrics={
                 "continuous": WrappedMetricWithPrepareFunction(
-                    metric=MetricCollection({"auroc": BinaryAUROC(thresholds=None)}),
+                    metric=MetricCollection(
+                        {
+                            "auroc": BinaryAUROC(),
+                            "avg-P": BinaryAveragePrecision(validate_args=False),
+                            # "roc": BinaryROC(validate_args=False),
+                            # "PRCurve": BinaryPrecisionRecallCurve(validate_args=False),
+                        }
+                    ),
                     prepare_function=_get_scores,
                 ),
                 "discrete": WrappedMetricWithPrepareFunction(
diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py
index 010f9152f..b904d5c95 100644
--- a/tests/taskmodules/test_cross_text_binary_coref.py
+++ b/tests/taskmodules/test_cross_text_binary_coref.py
@@ -328,6 +328,8 @@ def test_configure_metric(taskmodule, batch):
         {
             "continuous/auroc/preds": [],
             "continuous/auroc/target": [],
+            "continuous/avg-P/preds": [],
+            "continuous/avg-P/target": [],
             "discrete/f1_per_label/tp": tensor([0, 0]),
             "discrete/f1_per_label/fp": tensor([0, 0]),
             "discrete/f1_per_label/tn": tensor([0, 0]),
@@ -355,6 +357,8 @@ def test_configure_metric(taskmodule, batch):
         {
             "continuous/auroc/preds": [tensor([0.0, 1.0, 0.0, 0.0])],
             "continuous/auroc/target": [tensor([0.0, 1.0, 0.0, 0.0])],
+            "continuous/avg-P/preds": [tensor([0.0, 1.0, 0.0, 0.0])],
+            "continuous/avg-P/target": [tensor([0.0, 1.0, 0.0, 0.0])],
             "discrete/f1_per_label/tp": tensor([3, 1]),
             "discrete/f1_per_label/fp": tensor([0, 0]),
             "discrete/f1_per_label/tn": tensor([1, 3]),
@@ -374,6 +378,7 @@ def test_configure_metric(taskmodule, batch):
         metric.compute(),
         {
             "auroc": tensor(1.0),
+            "avg-P": tensor(1.0),
             "no_relation/f1": tensor(1.0),
             "coref/f1": tensor(1.0),
             "macro/f1": tensor(1.0),
@@ -398,6 +403,14 @@ def test_configure_metric(taskmodule, batch):
                 tensor([0.0, 1.0, 0.0, 0.0]),
                 tensor([0.0, 1.0, 0.0, 0.0]),
             ],
+            "continuous/avg-P/preds": [
+                tensor([0.0, 1.0, 0.0, 0.0]),
+                tensor([0.2703, 0.6812, 0.2582, 0.9030]),
+            ],
+            "continuous/avg-P/target": [
+                tensor([0.0, 1.0, 0.0, 0.0]),
+                tensor([0.0, 1.0, 0.0, 0.0]),
+            ],
             "discrete/f1_per_label/tp": tensor([5, 1]),
             "discrete/f1_per_label/fp": tensor([1, 1]),
             "discrete/f1_per_label/tn": tensor([1, 5]),
@@ -417,6 +430,7 @@ def test_configure_metric(taskmodule, batch):
         metric.compute(),
         {
             "auroc": tensor(0.916667),
+            "avg-P": tensor(0.833333),
             "no_relation/f1": tensor(0.833333),
             "coref/f1": tensor(0.500000),
             "macro/f1": tensor(0.666667),
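`BinaryAveragePrecision` summarizes the precision-recall curve into a single score, which complements AUROC for the heavily imbalanced coref candidate sets produced above. A minimal standalone check of the torchmetrics API as wired in here (synthetic values):

```python
import torch
from torchmetrics.classification import BinaryAveragePrecision

avg_p = BinaryAveragePrecision(validate_args=False)
avg_p.update(torch.tensor([0.1, 0.9, 0.2, 0.8]), torch.tensor([0, 1, 0, 1]))
print(avg_p.compute())  # tensor(1.) when all positives are ranked on top
```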
From a30fc2f0eb856227e03a7046f410962f0f4ccf29 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 16 Sep 2024 16:54:22 +0200
Subject: [PATCH 47/49] use BinaryF1Score instead of MultiClassF1 (micro/macro/per-label)

---
 .../taskmodules/cross_text_binary_coref.py | 26 +-------
 .../test_cross_text_binary_coref.py | 66 ++++---------------
 2 files changed, 17 insertions(+), 75 deletions(-)

diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py
index 2ceb890cb..9258edd55 100644
--- a/src/pie_modules/taskmodules/cross_text_binary_coref.py
+++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py
@@ -20,10 +20,12 @@
 from pytorch_ie.annotations import Span
 from pytorch_ie.core import TaskEncoding, TaskModule
 from pytorch_ie.utils.window import get_window_around_slice
+from torch.nn.functional import threshold
 from torchmetrics import ClasswiseWrapper, F1Score, Metric, MetricCollection
 from torchmetrics.classification import (
     BinaryAUROC,
     BinaryAveragePrecision,
+    BinaryF1Score,
     BinaryPrecisionRecallCurve,
     BinaryROC,
 )
@@ -257,13 +259,6 @@ def collate(
         return inputs, targets
 
     def configure_model_metric(self, stage: str) -> MetricCollection:
-        # we use the length of label_to_id because that contains the none_label (in contrast to labels)
-        labels = ["no_relation", "coref"]
-        common_metric_kwargs = {
-            "num_classes": len(labels),
-            "task": "multiclass",
-        }
-
         return MetricCollection(
             metrics={
                 "continuous": WrappedMetricWithPrepareFunction(
@@ -273,26 +268,11 @@ def configure_model_metric(self, stage: str) -> MetricCollection:
                             "avg-P": BinaryAveragePrecision(validate_args=False),
                             # "roc": BinaryROC(validate_args=False),
                             # "PRCurve": BinaryPrecisionRecallCurve(validate_args=False),
+                            "f1": BinaryF1Score(threshold=self.similarity_threshold),
                         }
                     ),
                     prepare_function=_get_scores,
                 ),
-                "discrete": WrappedMetricWithPrepareFunction(
-                    metric=MetricCollection(
-                        {
-                            "micro/f1": F1Score(average="micro", **common_metric_kwargs),
-                            "macro/f1": F1Score(average="macro", **common_metric_kwargs),
-                            "f1_per_label": ClasswiseWrapper(
-                                F1Score(average=None, **common_metric_kwargs),
-                                labels=labels,
-                                postfix="/f1",
-                            ),
-                        }
-                    ),
-                    prepare_function=partial(
-                        _get_labels, label_threshold=self.similarity_threshold
-                    ),
-                ),
             }
         )
diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py
index b904d5c95..26c10bc45 100644
--- a/tests/taskmodules/test_cross_text_binary_coref.py
+++ b/tests/taskmodules/test_cross_text_binary_coref.py
@@ -330,18 +330,10 @@ def test_configure_metric(taskmodule, batch):
             "continuous/auroc/target": [],
             "continuous/avg-P/preds": [],
             "continuous/avg-P/target": [],
-            "discrete/f1_per_label/tp": tensor([0, 0]),
-            "discrete/f1_per_label/fp": tensor([0, 0]),
-            "discrete/f1_per_label/tn": tensor([0, 0]),
-            "discrete/f1_per_label/fn": tensor([0, 0]),
-            "discrete/macro/f1/tp": tensor([0, 0]),
-            "discrete/macro/f1/fp": tensor([0, 0]),
-            "discrete/macro/f1/tn": tensor([0, 0]),
-            "discrete/macro/f1/fn": tensor([0, 0]),
-            "discrete/micro/f1/tp": tensor([0]),
-            "discrete/micro/f1/fp": tensor([0]),
-            "discrete/micro/f1/tn": tensor([0]),
-            "discrete/micro/f1/fn": tensor([0]),
+            "continuous/f1/fn": tensor([0]),
+            "continuous/f1/fp": tensor([0]),
+            "continuous/f1/tn": tensor([0]),
+            "continuous/f1/tp": tensor([0]),
         },
     )
 
@@ -351,16 +343,16 @@ def test_configure_metric(taskmodule, batch):
             "continuous/auroc/preds": [tensor([0.0, 1.0, 0.0, 0.0])],
             "continuous/auroc/target": [tensor([0.0, 1.0, 0.0, 0.0])],
             "continuous/avg-P/preds": [tensor([0.0, 1.0, 0.0, 0.0])],
             "continuous/avg-P/target": [tensor([0.0, 1.0, 0.0, 0.0])],
-            "discrete/f1_per_label/tp": tensor([3, 1]),
-            "discrete/f1_per_label/fp": tensor([0, 0]),
-            "discrete/f1_per_label/tn": tensor([1, 3]),
-            "discrete/f1_per_label/fn": tensor([0, 0]),
-            "discrete/macro/f1/tp": tensor([3, 1]),
-            "discrete/macro/f1/fp": tensor([0, 0]),
-            "discrete/macro/f1/tn": tensor([1, 3]),
-            "discrete/macro/f1/fn": tensor([0, 0]),
-            "discrete/micro/f1/tp": tensor([4]),
-            "discrete/micro/f1/fp": tensor([0]),
-            "discrete/micro/f1/tn": tensor([4]),
-            "discrete/micro/f1/fn": tensor([0]),
+            "continuous/f1/tp": tensor([1]),
+            "continuous/f1/fp": tensor([0]),
+            "continuous/f1/tn": tensor([3]),
+            "continuous/f1/fn": tensor([0]),
         },
     )
 
     torch.testing.assert_close(
         metric.compute(),
-        {
-            "auroc": tensor(1.0),
-            "avg-P": tensor(1.0),
-            "no_relation/f1": tensor(1.0),
-            "coref/f1": tensor(1.0),
-            "macro/f1": tensor(1.0),
-            "micro/f1": tensor(1.0),
-        },
+        {"auroc": tensor(1.0), "avg-P": tensor(1.0), "f1": tensor(1.0)},
     )
 
     # torch.rand_like(targets)
     random_targets = {
         "scores": torch.tensor([0.2703, 0.6812, 0.2582, 0.9030]),
     }
     metric.update(random_targets, targets)
     state = get_metric_state(metric)
     torch.testing.assert_close(
         state,
         {
             "continuous/auroc/preds": [
                 tensor([0.0, 1.0, 0.0, 0.0]),
                 tensor([0.2703, 0.6812, 0.2582, 0.9030]),
             ],
             "continuous/auroc/target": [
                 tensor([0.0, 1.0, 0.0, 0.0]),
                 tensor([0.0, 1.0, 0.0, 0.0]),
             ],
-            "discrete/f1_per_label/tp": tensor([5, 1]),
-            "discrete/f1_per_label/fp": tensor([1, 1]),
-            "discrete/f1_per_label/tn": tensor([1, 5]),
-            "discrete/f1_per_label/fn": tensor([1, 1]),
-            "discrete/macro/f1/tp": tensor([5, 1]),
-            "discrete/macro/f1/fp": tensor([1, 1]),
-            "discrete/macro/f1/tn": tensor([1, 5]),
-            "discrete/macro/f1/fn": tensor([1, 1]),
-            "discrete/micro/f1/tp": tensor([6]),
-            "discrete/micro/f1/fp": tensor([2]),
-            "discrete/micro/f1/tn": tensor([6]),
-            "discrete/micro/f1/fn": tensor([2]),
+            "continuous/f1/tp": tensor([1]),
+            "continuous/f1/fp": tensor([1]),
+            "continuous/f1/tn": tensor([5]),
+            "continuous/f1/fn": tensor([1]),
         },
     )
 
     torch.testing.assert_close(
         metric.compute(),
-        {
-            "auroc": tensor(0.916667),
-            "avg-P": tensor(0.833333),
-            "no_relation/f1": tensor(0.833333),
-            "coref/f1": tensor(0.500000),
-            "macro/f1": tensor(0.666667),
-            "micro/f1": tensor(0.750000),
-        },
+        {"auroc": tensor(0.916667), "avg-P": tensor(0.916667), "f1": tensor(0.500000)},
     )

From 904fac9b9cc04ff7dc8a73dd782a6bfdd02a524b Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 16 Sep 2024 18:01:46 +0200
Subject: [PATCH 48/49] fix test

---
 tests/taskmodules/test_cross_text_binary_coref.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py
index 26c10bc45..5b16d2af1 100644
--- a/tests/taskmodules/test_cross_text_binary_coref.py
+++ b/tests/taskmodules/test_cross_text_binary_coref.py
@@ -397,5 +397,5 @@ def test_configure_metric(taskmodule, batch):
 
     torch.testing.assert_close(
         metric.compute(),
-        {"auroc": tensor(0.916667), "avg-P": tensor(0.916667), "f1": tensor(0.500000)},
+        {"auroc": tensor(0.91666663), "avg-P": tensor(0.83333337), "f1": tensor(0.50000000)},
     )

From c76caca93858b0a67689b2e5402c022813dd2841 Mon Sep 17 00:00:00 2001
From: Arne Binder
Date: Mon, 16 Sep 2024 18:48:18 +0200
Subject: [PATCH 49/49] cleanup

---
 src/pie_modules/taskmodules/cross_text_binary_coref.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py
index 9258edd55..f40a38077 100644
--- a/src/pie_modules/taskmodules/cross_text_binary_coref.py
+++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py
@@ -1,6 +1,5 @@
 import copy
 import logging
-from functools import partial
 from typing import (
     Any,
     Dict,
@@ -20,14 +19,11 @@
 from pytorch_ie.annotations import Span
 from pytorch_ie.core import TaskEncoding, TaskModule
 from pytorch_ie.utils.window import get_window_around_slice
-from torch.nn.functional import threshold
-from torchmetrics import ClasswiseWrapper, F1Score, Metric, MetricCollection
+from torchmetrics import MetricCollection
 from torchmetrics.classification import (
     BinaryAUROC,
     BinaryAveragePrecision,
     BinaryF1Score,
-    BinaryPrecisionRecallCurve,
-    BinaryROC,
 )
 from transformers import AutoTokenizer, BatchEncoding
 from typing_extensions import TypeAlias