diff --git a/README.md b/README.md index d6576146f..dac0027c0 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ Available models: - [SimpleSequenceClassificationModel](src/pie_modules/models/simple_sequence_classification.py) - [SequenceClassificationModelWithPooler](src/pie_modules/models/sequence_classification_with_pooler.py) +- [SequencePairSimilarityModelWithPooler](src/pie_modules/models/sequence_classification_with_pooler.py) - [SimpleTokenClassificationModel](src/pie_modules/models/simple_token_classification.py) - [TokenClassificationModelWithSeq2SeqEncoderAndCrf](src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py) - [SimpleExtractiveQuestionAnsweringModel](src/pie_modules/models/simple_extractive_question_answering.py) @@ -25,6 +26,7 @@ Available models: Available taskmodules: - [RETextClassificationWithIndicesTaskModule](src/pie_modules/taskmodules/re_text_classification_with_indices.py) +- [CrossTextBinaryCorefTaskModule](src/pie_modules/taskmodules/cross_text_binary_coref.py) - [LabeledSpanExtractionByTokenClassificationTaskModule](src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py) - [ExtractiveQuestionAnsweringTaskModule](src/pie_modules/taskmodules/extractive_question_answering.py) - [TextToTextTaskModule](src/pie_modules/taskmodules/text_to_text.py) diff --git a/src/pie_modules/annotations.py b/src/pie_modules/annotations.py index 09a31bbee..208c2d1b9 100644 --- a/src/pie_modules/annotations.py +++ b/src/pie_modules/annotations.py @@ -63,3 +63,8 @@ class GenerativeAnswer(AnnotationWithText): score: Optional[float] = dataclasses.field(default=None, compare=False) question: Optional[Question] = None + + +@dataclasses.dataclass(eq=True, frozen=True) +class BinaryCorefRelation(BinaryRelation): + label: str = "coref" diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py new file mode 100644 index 000000000..756619eda --- /dev/null +++ b/src/pie_modules/document/processing/text_pair.py @@ -0,0 +1,254 @@ +import copy +import logging +import random +from collections import defaultdict +from collections.abc import Iterator +from itertools import chain +from typing import Dict, Iterable, List, Optional, Tuple, TypeVar + +from pytorch_ie.annotations import LabeledSpan, Span +from pytorch_ie.documents import ( + TextDocumentWithLabeledSpansAndBinaryRelations, + TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, +) +from tqdm import tqdm + +from pie_modules.documents import ( + BinaryCorefRelation, + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, +) +from pie_modules.utils.span import are_nested + +logger = logging.getLogger(__name__) + +S = TypeVar("S", bound=Span) +S2 = TypeVar("S2", bound=Span) + + +def _span2partition_mapping(spans: Iterable[S], partitions: Iterable[S2]) -> Dict[S, S2]: + result = {} + for span in spans: + for partition in partitions: + if are_nested( + start_end=(span.start, span.end), other_start_end=(partition.start, partition.end) + ): + result[span] = partition + break + return result + + +def _span_copy_shifted(span: S, offset: int) -> S: + return span.copy(start=span.start + offset, end=span.end + offset) + + +def _construct_text_pair_coref_documents_from_partitions_via_relations( + document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, relation_label: str +) -> List[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]: + span2partition = _span2partition_mapping( + spans=document.labeled_spans, partitions=document.labeled_partitions + ) + partition2spans = defaultdict(list) + for span, partition in span2partition.items(): + partition2spans[partition].append(span) + + texts2docs_and_span_mappings: Dict[ + Tuple[str, str], + Tuple[ + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, + Dict[LabeledSpan, LabeledSpan], + Dict[LabeledSpan, LabeledSpan], + ], + ] = dict() + result = [] + for rel in document.binary_relations: + if rel.label != relation_label: + continue + + if rel.head not in span2partition: + raise ValueError(f"head not in any partition: {rel.head}") + head_partition = span2partition[rel.head] + text = document.text[head_partition.start : head_partition.end] + + if rel.tail not in span2partition: + raise ValueError(f"tail not in any partition: {rel.tail}") + tail_partition = span2partition[rel.tail] + text_pair = document.text[tail_partition.start : tail_partition.end] + + if (text, text_pair) in texts2docs_and_span_mappings: + new_doc, head_spans_mapping, tail_spans_mapping = texts2docs_and_span_mappings[ + (text, text_pair) + ] + else: + if document.id is not None: + doc_id = ( + f"{document.id}[{head_partition.start}:{head_partition.end}]" + f"+{document.id}[{tail_partition.start}:{tail_partition.end}]" + ) + else: + doc_id = None + new_doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + id=doc_id, text=text, text_pair=text_pair + ) + + head_spans_mapping = { + span: _span_copy_shifted(span=span, offset=-head_partition.start) + for span in partition2spans[head_partition] + } + new_doc.labeled_spans.extend(head_spans_mapping.values()) + + tail_spans_mapping = { + span: _span_copy_shifted(span=span, offset=-tail_partition.start) + for span in partition2spans[tail_partition] + } + new_doc.labeled_spans_pair.extend(tail_spans_mapping.values()) + + texts2docs_and_span_mappings[(text, text_pair)] = ( + new_doc, + head_spans_mapping, + tail_spans_mapping, + ) + result.append(new_doc) + + coref_rel = BinaryCorefRelation( + head=head_spans_mapping[rel.head], tail=tail_spans_mapping[rel.tail], score=1.0 + ) + new_doc.binary_coref_relations.append(coref_rel) + + return result + + +def construct_text_pair_coref_documents_from_partitions_via_relations( + documents: Iterable[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions], **kwargs +) -> Iterator[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]: + for doc in documents: + yield from _construct_text_pair_coref_documents_from_partitions_via_relations( + document=doc, **kwargs + ) + + +def shift_span(span: S, offset: int) -> S: + return span.copy(start=span.start + offset, end=span.end + offset) + + +def construct_text_document_from_text_pair_coref_document( + document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, + glue_text: str, + no_relation_label: str, + relation_label_mapping: Optional[Dict[str, str]] = None, +) -> TextDocumentWithLabeledSpansAndBinaryRelations: + if document.text == document.text_pair: + new_doc = TextDocumentWithLabeledSpansAndBinaryRelations( + id=document.id, metadata=copy.deepcopy(document.metadata), text=document.text + ) + old2new_spans: Dict[LabeledSpan, LabeledSpan] = {} + new2new_spans: Dict[LabeledSpan, LabeledSpan] = {} + for old_span in chain(document.labeled_spans, document.labeled_spans_pair): + new_span = old_span.copy() + # when detaching / copying the span, it may be the same as a previous span from the other + new_span = new2new_spans.get(new_span, new_span) + new2new_spans[new_span] = new_span + old2new_spans[old_span] = new_span + else: + new_doc = TextDocumentWithLabeledSpansAndBinaryRelations( + text=document.text + glue_text + document.text_pair, + id=document.id, + metadata=copy.deepcopy(document.metadata), + ) + old2new_spans = {} + old2new_spans.update({span: span.copy() for span in document.labeled_spans}) + offset = len(document.text) + len(glue_text) + old2new_spans.update( + {span: shift_span(span.copy(), offset) for span in document.labeled_spans_pair} + ) + + # sort to make order deterministic + new_doc.labeled_spans.extend( + sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label)) + ) + for old_rel in document.binary_coref_relations: + label = old_rel.label if old_rel.score > 0.0 else no_relation_label + if relation_label_mapping is not None: + label = relation_label_mapping.get(label, label) + new_rel = old_rel.copy( + head=old2new_spans[old_rel.head], + tail=old2new_spans[old_rel.tail], + label=label, + score=1.0, + ) + new_doc.binary_relations.append(new_rel) + + return new_doc + + +def add_negative_coref_relations( + documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations], + downsampling_factor: Optional[float] = None, +) -> Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]: + positive_tuples = defaultdict(set) + text2spans = defaultdict(set) + for doc in documents: + for labeled_span in doc.labeled_spans: + text2spans[doc.text].add(labeled_span.copy()) + for labeled_span in doc.labeled_spans_pair: + text2spans[doc.text_pair].add(labeled_span.copy()) + + for coref in doc.binary_coref_relations: + positive_tuples[(doc.text, doc.text_pair)].add((coref.head.copy(), coref.tail.copy())) + positive_tuples[(doc.text_pair, doc.text)].add((coref.tail.copy(), coref.head.copy())) + + new_docs = [] + new_rels2new_docs = {} + positive_rels = [] + negative_rels = [] + for text in tqdm(sorted(text2spans)): + for text_pair in sorted(text2spans): + current_positives = positive_tuples.get((text, text_pair), set()) + new_doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + text=text, text_pair=text_pair + ) + new_doc.labeled_spans.extend(labeled_span.copy() for labeled_span in text2spans[text]) + new_doc.labeled_spans_pair.extend( + labeled_span.copy() for labeled_span in text2spans[text_pair] + ) + for s in sorted(new_doc.labeled_spans): + for s_p in sorted(new_doc.labeled_spans_pair): + # exclude relations to itself + if text == text_pair and s.copy() == s_p.copy(): + continue + if s.label != s_p.label: + continue + score = 1.0 if (s.copy(), s_p.copy()) in current_positives else 0.0 + new_coref_rel = BinaryCorefRelation(head=s, tail=s_p, score=score) + # new_doc.binary_coref_relations.append(new_coref_rel) + new_rels2new_docs[new_coref_rel] = new_doc + if score > 0.0: + positive_rels.append(new_coref_rel) + else: + negative_rels.append(new_coref_rel) + new_docs.append(new_doc) + + for rel in positive_rels: + new_rels2new_docs[rel].binary_coref_relations.append(rel) + + # Downsampling of negatives. This requires positive instances! + if downsampling_factor is not None: + if len(positive_rels) == 0: + raise ValueError( + f"downsampling [factor={downsampling_factor}] is enabled, " + f"but no positive relations are available to calculate max_num_negative" + ) + + max_num_negative = int(len(positive_rels) * downsampling_factor) + if max_num_negative == 0: + logger.warning( + f"downsampling with factor={downsampling_factor} and number of " + f"positive relations={len(positive_rels)} does not produce any negatives" + ) + else: + random.shuffle(negative_rels) + negative_rels = negative_rels[:max_num_negative] + for rel in negative_rels: + new_rels2new_docs[rel].binary_coref_relations.append(rel) + + docs_with_rels = [doc for doc in new_docs if len(doc.binary_coref_relations) > 0] + return docs_with_rels diff --git a/src/pie_modules/documents.py b/src/pie_modules/documents.py index 4dae18b78..3de89e850 100644 --- a/src/pie_modules/documents.py +++ b/src/pie_modules/documents.py @@ -27,6 +27,7 @@ from pie_modules.annotations import ( AbstractiveSummary, + BinaryCorefRelation, BinaryRelation, ExtractiveAnswer, GenerativeAnswer, @@ -151,3 +152,63 @@ class TokenDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions( TokenDocumentWithLabeledMultiSpansAndBinaryRelations, ): pass + + +@dataclasses.dataclass +class WithTextPair: + text_pair: str + + +@dataclasses.dataclass +class WithLabeledSpansPair(WithTextPair): + labeled_spans_pair: AnnotationLayer[LabeledSpan] = annotation_field(target="text_pair") + + +@dataclasses.dataclass +class WithLabeledPartitionsPair(WithTextPair): + labeled_partitions_pair: AnnotationLayer[LabeledSpan] = annotation_field(target="text_pair") + + +@dataclasses.dataclass +class TextPairBasedDocument(TextBasedDocument, WithTextPair): + pass + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledPartitions( + WithLabeledPartitionsPair, TextPairBasedDocument, TextDocumentWithLabeledPartitions +): + pass + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledSpans( + WithLabeledSpansPair, TextPairBasedDocument, TextDocumentWithLabeledSpans +): + pass + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledSpansAndLabeledPartitions( + TextPairDocumentWithLabeledPartitions, + TextPairDocumentWithLabeledSpans, + TextDocumentWithLabeledSpansAndLabeledPartitions, +): + pass + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + TextPairDocumentWithLabeledSpans, TextDocumentWithLabeledSpans +): + binary_coref_relations: AnnotationLayer[BinaryCorefRelation] = annotation_field( + targets=["labeled_spans", "labeled_spans_pair"] + ) + + +@dataclasses.dataclass +class TextPairDocumentWithLabeledSpansSimilarityRelationsAndLabeledPartitions( + TextPairDocumentWithLabeledSpansAndLabeledPartitions, + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, +): + pass diff --git a/src/pie_modules/models/__init__.py b/src/pie_modules/models/__init__.py index f64038f80..df8f4a035 100644 --- a/src/pie_modules/models/__init__.py +++ b/src/pie_modules/models/__init__.py @@ -1,4 +1,7 @@ -from .sequence_classification_with_pooler import SequenceClassificationModelWithPooler +from .sequence_classification_with_pooler import ( + SequenceClassificationModelWithPooler, + SequencePairSimilarityModelWithPooler, +) from .simple_extractive_question_answering import SimpleExtractiveQuestionAnsweringModel from .simple_generative import SimpleGenerativeModel from .simple_sequence_classification import SimpleSequenceClassificationModel diff --git a/src/pie_modules/models/sequence_classification_with_pooler.py b/src/pie_modules/models/sequence_classification_with_pooler.py index 98a6965da..cc5c2ea39 100644 --- a/src/pie_modules/models/sequence_classification_with_pooler.py +++ b/src/pie_modules/models/sequence_classification_with_pooler.py @@ -266,3 +266,77 @@ def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: probabilities = torch.sigmoid(outputs.logits) labels = (probabilities > self.multi_label_threshold).to(torch.long) return {"labels": labels, "probabilities": probabilities} + + +@PyTorchIEModel.register() +class SequencePairSimilarityModelWithPooler( + SequenceClassificationModelWithPoolerBase, +): + """A span pair similarity model to detect of two spans occurring in different texts are + similar. It uses an encoder to independently calculate contextualized embeddings of both texts, + then uses a pooler to get representations of the spans and, finally, calculates the cosine to + get the similarity scores. + + Args: + label_threshold: The threshold above which score the spans are considered as similar. + pooler: The pooler identifier or config, see :func:`get_pooler_and_output_size` for details. + Defaults to "mention_pooling" (max pooling over the span token embeddings). + **kwargs + """ + + def __init__( + self, + pooler: Optional[Union[Dict[str, Any], str]] = None, + **kwargs, + ): + if pooler is None: + # use (max) mention pooling per default + pooler = {"type": "mention_pooling", "num_indices": 1} + super().__init__(pooler=pooler, **kwargs) + + def setup_classifier( + self, pooler_output_dim: int + ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]: + return torch.nn.functional.cosine_similarity + + def setup_loss_fct(self) -> Callable: + return nn.BCELoss() + + def forward( + self, + inputs: InputType, + targets: Optional[TargetType] = None, + return_hidden_states: bool = False, + ) -> OutputType: + sanitized_inputs = separate_arguments_by_prefix( + # Note that the order of the prefixes is important because one is a prefix of the other, + # so we need to start with the longer! + arguments=inputs, + prefixes=["pooler_pair_", "pooler_"], + ) + + pooled_output = self.get_pooled_output( + model_inputs=sanitized_inputs["remaining"]["encoding"], + pooler_inputs=sanitized_inputs["pooler_"], + ) + pooled_output_pair = self.get_pooled_output( + model_inputs=sanitized_inputs["remaining"]["encoding_pair"], + pooler_inputs=sanitized_inputs["pooler_pair_"], + ) + + logits = self.classifier(pooled_output, pooled_output_pair) + + result = {"logits": logits} + if targets is not None: + labels = targets["scores"] + loss = self.loss_fct(logits, labels) + result["loss"] = loss + if return_hidden_states: + raise NotImplementedError("return_hidden_states is not yet implemented") + + return SequenceClassifierOutput(**result) + + def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: + # probabilities = torch.sigmoid(outputs.logits) + scores = outputs.logits + return {"scores": scores} diff --git a/src/pie_modules/taskmodules/__init__.py b/src/pie_modules/taskmodules/__init__.py index 7fdedab4c..46d3766ae 100644 --- a/src/pie_modules/taskmodules/__init__.py +++ b/src/pie_modules/taskmodules/__init__.py @@ -1,3 +1,4 @@ +from .cross_text_binary_coref import CrossTextBinaryCorefTaskModule from .extractive_question_answering import ExtractiveQuestionAnsweringTaskModule from .labeled_span_extraction_by_token_classification import ( LabeledSpanExtractionByTokenClassificationTaskModule, diff --git a/src/pie_modules/taskmodules/common/mixins.py b/src/pie_modules/taskmodules/common/mixins.py index 876d2f9cc..4a6f647d1 100644 --- a/src/pie_modules/taskmodules/common/mixins.py +++ b/src/pie_modules/taskmodules/common/mixins.py @@ -185,7 +185,10 @@ def finalize_statistics(self): else: raise ValueError(f"unknown key: {key}") for rel in rels_set: - self.increase_counter(key=(key, rel.label)) + # Set "no_relation" as label when the score is zero. We encode negative relations + # in such a way in the case of multi-label or binary (similarity for coref). + label = rel.label if rel.score > 0 else "no_relation" + self.increase_counter(key=(key, label)) for rel in skipped_other: self.increase_counter(key=("skipped_other", rel.label)) diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py new file mode 100644 index 000000000..f40a38077 --- /dev/null +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -0,0 +1,291 @@ +import copy +import logging +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + TypedDict, + TypeVar, + Union, +) + +import torch +from pytorch_ie import Annotation +from pytorch_ie.annotations import Span +from pytorch_ie.core import TaskEncoding, TaskModule +from pytorch_ie.utils.window import get_window_around_slice +from torchmetrics import MetricCollection +from torchmetrics.classification import ( + BinaryAUROC, + BinaryAveragePrecision, + BinaryF1Score, +) +from transformers import AutoTokenizer, BatchEncoding +from typing_extensions import TypeAlias + +from pie_modules.documents import ( + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, +) +from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin +from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction +from pie_modules.utils import list_of_dicts2dict_of_lists +from pie_modules.utils.tokenization import ( + SpanNotAlignedWithTokenException, + get_aligned_token_span, +) + +logger = logging.getLogger(__name__) + +InputEncodingType: TypeAlias = Dict[str, Any] +TargetEncodingType: TypeAlias = Sequence[float] +DocumentType: TypeAlias = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations + +TaskEncodingType: TypeAlias = TaskEncoding[ + DocumentType, + InputEncodingType, + TargetEncodingType, +] + + +class TaskOutputType(TypedDict, total=False): + score: float + is_similar: bool + + +ModelInputType: TypeAlias = Dict[str, torch.Tensor] +ModelTargetType: TypeAlias = Dict[str, torch.Tensor] +ModelOutputType: TypeAlias = Dict[str, torch.Tensor] + +TaskModuleType: TypeAlias = TaskModule[ + # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput + DocumentType, + InputEncodingType, + TargetEncodingType, + Tuple[ModelInputType, Optional[ModelTargetType]], + ModelTargetType, + TaskOutputType, +] + + +class SpanDoesNotFitIntoAvailableWindow(Exception): + def __init__(self, span): + self.span = span + + +def _get_labels(model_output: ModelTargetType, label_threshold: float) -> torch.Tensor: + return (model_output["scores"] > label_threshold).to(torch.int) + + +def _get_scores(model_output: ModelTargetType) -> torch.Tensor: + return model_output["scores"] + + +S = TypeVar("S", bound=Span) + + +def shift_span(span: S, offset: int) -> S: + return span.copy(start=span.start + offset, end=span.end + offset) + + +@TaskModule.register() +class CrossTextBinaryCorefTaskModule(RelationStatisticsMixin, TaskModuleType): + """This taskmodule processes documents of type + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations in preparation for a + SequencePairSimilarityModelWithPooler.""" + + DOCUMENT_TYPE = DocumentType + + def __init__( + self, + tokenizer_name_or_path: str, + similarity_threshold: float = 0.9, + max_window: Optional[int] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.save_hyperparameters() + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + self.similarity_threshold = similarity_threshold + self.max_window = max_window if max_window is not None else self.tokenizer.model_max_length + self.available_window = self.max_window - self.tokenizer.num_special_tokens_to_add() + self.num_special_tokens_before = len(self._get_special_tokens_before_input()) + + def _get_special_tokens_before_input(self) -> List[int]: + dummy_ids = self.tokenizer.build_inputs_with_special_tokens(token_ids_0=[-1]) + return dummy_ids[: dummy_ids.index(-1)] + + def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs): + self.reset_statistics() + result = super().encode(documents=documents, **kwargs) + self.show_statistics() + return result + + def truncate_encoding_around_span( + self, encoding: BatchEncoding, char_span: Span + ) -> Tuple[Dict[str, List[int]], Span]: + input_ids = copy.deepcopy(encoding["input_ids"]) + + token_span = get_aligned_token_span(encoding=encoding, char_span=char_span) + + # truncate input_ids and shift token_start and token_end + if len(input_ids) > self.available_window: + window_slice = get_window_around_slice( + slice=(token_span.start, token_span.end), + max_window_size=self.available_window, + available_input_length=len(input_ids), + ) + if window_slice is None: + raise SpanDoesNotFitIntoAvailableWindow(span=token_span) + window_start, window_end = window_slice + input_ids = input_ids[window_start:window_end] + token_span = shift_span(token_span, offset=-window_start) + + truncated_encoding = self.tokenizer.prepare_for_model(ids=input_ids) + # shift indices because we added special tokens to the input_ids + token_span = shift_span(token_span, offset=self.num_special_tokens_before) + + return truncated_encoding, token_span + + def encode_input( + self, + document: DocumentType, + is_training: bool = False, + ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]: + self.collect_all_relations(kind="available", relations=document.binary_coref_relations) + tokenizer_kwargs = dict( + padding=False, + truncation=False, + add_special_tokens=False, + ) + encoding = self.tokenizer(text=document.text, **tokenizer_kwargs) + encoding_pair = self.tokenizer(text=document.text_pair, **tokenizer_kwargs) + + task_encodings = [] + for coref_rel in document.binary_coref_relations: + # TODO: This can miss instances if both texts are the same. We could check that + # coref_rel.head is in document.labeled_spans (same for the tail), but would this + # slow down the encoding? + if not ( + coref_rel.head.target == document.text + or coref_rel.tail.target == document.text_pair + ): + raise ValueError( + f"It is expected that coref relations go from (head) spans over 'text' " + f"to (tail) spans over 'text_pair', but this is not the case for this " + f"relation (i.e. it points into the other direction): {coref_rel.resolve()}" + ) + try: + current_encoding, token_span = self.truncate_encoding_around_span( + encoding=encoding, char_span=coref_rel.head + ) + current_encoding_pair, token_span_pair = self.truncate_encoding_around_span( + encoding=encoding_pair, char_span=coref_rel.tail + ) + except SpanNotAlignedWithTokenException as e: + logger.warning( + f"Could not get token offsets for argument ({e.span}) of coref relation: " + f"{coref_rel.resolve()}. Skip it." + ) + self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel) + continue + except SpanDoesNotFitIntoAvailableWindow as e: + logger.warning( + f"Argument span [{e.span}] does not fit into available token window " + f"({self.available_window}). Skip it." + ) + self.collect_relation( + kind="skipped_span_does_not_fit_into_window", relation=coref_rel + ) + continue + + task_encodings.append( + TaskEncoding( + document=document, + inputs={ + "encoding": current_encoding, + "encoding_pair": current_encoding_pair, + "pooler_start_indices": token_span.start, + "pooler_end_indices": token_span.end, + "pooler_pair_start_indices": token_span_pair.start, + "pooler_pair_end_indices": token_span_pair.end, + }, + metadata={"candidate_annotation": coref_rel}, + ) + ) + self.collect_relation("used", coref_rel) + return task_encodings + + def encode_target( + self, + task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType], + ) -> Optional[TargetEncodingType]: + return task_encoding.metadata["candidate_annotation"].score + + def collate( + self, + task_encodings: Sequence[ + TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType] + ], + ) -> Tuple[ModelInputType, Optional[ModelTargetType]]: + inputs_dict = list_of_dicts2dict_of_lists( + [task_encoding.inputs for task_encoding in task_encodings] + ) + + inputs = { + k: self.tokenizer.pad(v, return_tensors="pt").data + if k in ["encoding", "encoding_pair"] + else torch.tensor(v) + for k, v in inputs_dict.items() + } + for k, v in inputs.items(): + if k.startswith("pooler_") and k.endswith("_indices"): + inputs[k] = v.unsqueeze(-1) + + if not task_encodings[0].has_targets: + return inputs, None + targets = { + "scores": torch.tensor([task_encoding.targets for task_encoding in task_encodings]) + } + return inputs, targets + + def configure_model_metric(self, stage: str) -> MetricCollection: + return MetricCollection( + metrics={ + "continuous": WrappedMetricWithPrepareFunction( + metric=MetricCollection( + { + "auroc": BinaryAUROC(), + "avg-P": BinaryAveragePrecision(validate_args=False), + # "roc": BinaryROC(validate_args=False), + # "PRCurve": BinaryPrecisionRecallCurve(validate_args=False), + "f1": BinaryF1Score(threshold=self.similarity_threshold), + } + ), + prepare_function=_get_scores, + ), + } + ) + + def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]: + is_similar = (model_output["scores"] > self.similarity_threshold).detach().cpu().tolist() + scores = model_output["scores"].detach().cpu().tolist() + result: List[TaskOutputType] = [ + {"is_similar": is_sim, "score": prob} for is_sim, prob in zip(is_similar, scores) + ] + return result + + def create_annotations_from_output( + self, + task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType], + task_output: TaskOutputType, + ) -> Iterator[Tuple[str, Annotation]]: + if task_output["is_similar"]: + score = task_output["score"] + new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score) + yield "binary_coref_relations", new_coref_rel diff --git a/src/pie_modules/utils/tokenization.py b/src/pie_modules/utils/tokenization.py new file mode 100644 index 000000000..addc2fb03 --- /dev/null +++ b/src/pie_modules/utils/tokenization.py @@ -0,0 +1,35 @@ +from typing import TypeVar + +from pytorch_ie.annotations import Span +from transformers import BatchEncoding + +S = TypeVar("S", bound=Span) + + +class SpanNotAlignedWithTokenException(Exception): + def __init__(self, span): + self.span = span + + +def get_aligned_token_span(encoding: BatchEncoding, char_span: S) -> S: + # find the start + token_start = None + token_end_before = None + char_start = None + for idx in range(char_span.start, char_span.end): + token_start = encoding.char_to_token(idx) + if token_start is not None: + char_start = idx + break + + if char_start is None: + raise SpanNotAlignedWithTokenException(span=char_span) + for idx in range(char_span.end - 1, char_start - 1, -1): + token_end_before = encoding.char_to_token(idx) + if token_end_before is not None: + break + + if token_start is None or token_end_before is None: + raise SpanNotAlignedWithTokenException(span=char_span) + + return char_span.copy(start=token_start, end=token_end_before + 1) diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py new file mode 100644 index 000000000..dd205ce7d --- /dev/null +++ b/tests/document/processing/test_text_pair.py @@ -0,0 +1,426 @@ +import random +from itertools import chain +from typing import List + +import pytest +from pytorch_ie.annotations import BinaryRelation, LabeledSpan +from pytorch_ie.documents import ( + TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, +) + +from pie_modules.document.processing.text_pair import ( + add_negative_coref_relations, + construct_text_document_from_text_pair_coref_document, + construct_text_pair_coref_documents_from_partitions_via_relations, +) +from pie_modules.documents import ( + BinaryCorefRelation, + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, +) + +SENTENCES = [ + "Entity A works at B.", + "And she founded C.", + "Bob loves his cat.", + "She sleeps a lot.", +] + + +@pytest.fixture(scope="module") +def text_documents() -> List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]: + doc1 = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions( + id="doc1", text=" ".join(SENTENCES[:2]) + ) + # add sentence partitions + doc1.labeled_partitions.append(LabeledSpan(start=0, end=len(SENTENCES[0]), label="sentence")) + doc1.labeled_partitions.append( + LabeledSpan( + start=len(SENTENCES[0]) + 1, + end=len(SENTENCES[0]) + 1 + len(SENTENCES[1]), + label="sentence", + ) + ) + # add spans + doc1.labeled_spans.append(LabeledSpan(start=0, end=8, label="PERSON")) + doc1.labeled_spans.append(LabeledSpan(start=18, end=19, label="COMPANY")) + doc1_sen2_offset = doc1.labeled_partitions[1].start + doc1.labeled_spans.append( + LabeledSpan(start=4 + doc1_sen2_offset, end=7 + doc1_sen2_offset, label="PERSON") + ) + doc1.labeled_spans.append( + LabeledSpan(start=16 + doc1_sen2_offset, end=17 + doc1_sen2_offset, label="COMPANY") + ) + # add relation + doc1.binary_relations.append( + BinaryRelation( + head=doc1.labeled_spans[0], tail=doc1.labeled_spans[2], label="semantically_same" + ) + ) + + doc2 = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions( + id="doc2", text=" ".join(SENTENCES[2:4]) + ) + # add sentence partitions + doc2.labeled_partitions.append(LabeledSpan(start=0, end=len(SENTENCES[2]), label="sentence")) + doc2.labeled_partitions.append( + LabeledSpan( + start=len(SENTENCES[2]) + 1, + end=len(SENTENCES[2]) + 1 + len(SENTENCES[3]), + label="sentence", + ) + ) + # add spans + doc2.labeled_spans.append(LabeledSpan(start=0, end=3, label="PERSON")) + doc2.labeled_spans.append(LabeledSpan(start=10, end=17, label="ANIMAL")) + doc2_sen2_offset = doc2.labeled_partitions[1].start + doc2.labeled_spans.append( + LabeledSpan(start=0 + doc2_sen2_offset, end=3 + doc2_sen2_offset, label="ANIMAL") + ) + # add relation + doc2.binary_relations.append( + BinaryRelation( + head=doc2.labeled_spans[1], tail=doc2.labeled_spans[2], label="semantically_same" + ) + ) + + return [doc1, doc2] + + +def test_simple_text_documents(text_documents): + assert len(text_documents) == 2 + doc = text_documents[0] + # test serialization + doc.copy() + # test sentences + assert doc.labeled_partitions.resolve() == [ + ("sentence", "Entity A works at B."), + ("sentence", "And she founded C."), + ] + # test spans + assert doc.labeled_spans.resolve() == [ + ("PERSON", "Entity A"), + ("COMPANY", "B"), + ("PERSON", "she"), + ("COMPANY", "C"), + ] + # test relation + assert doc.binary_relations.resolve() == [ + ("semantically_same", (("PERSON", "Entity A"), ("PERSON", "she"))) + ] + + doc = text_documents[1] + # test serialization + doc.copy() + # test sentences + assert doc.labeled_partitions.resolve() == [ + ("sentence", "Bob loves his cat."), + ("sentence", "She sleeps a lot."), + ] + # test spans + assert doc.labeled_spans.resolve() == [ + ("PERSON", "Bob"), + ("ANIMAL", "his cat"), + ("ANIMAL", "She"), + ] + # test relation + assert doc.binary_relations.resolve() == [ + ("semantically_same", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) + ] + + +def test_construct_text_pair_coref_documents_from_partitions_via_relations(text_documents): + all_docs = { + doc.id: doc + for doc in construct_text_pair_coref_documents_from_partitions_via_relations( + documents=text_documents, relation_label="semantically_same" + ) + } + assert set(all_docs) == {"doc2[0:18]+doc2[19:36]", "doc1[0:20]+doc1[21:39]"} + + doc = all_docs["doc2[0:18]+doc2[19:36]"] + assert doc.text == "Bob loves his cat." + assert doc.text_pair == "She sleeps a lot." + assert doc.labeled_spans.resolve() == [("PERSON", "Bob"), ("ANIMAL", "his cat")] + assert doc.labeled_spans_pair.resolve() == [("ANIMAL", "She")] + assert doc.binary_coref_relations.resolve() == [ + ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) + ] + + doc = all_docs["doc1[0:20]+doc1[21:39]"] + assert doc.text == "Entity A works at B." + assert doc.text_pair == "And she founded C." + assert doc.labeled_spans.resolve() == [("PERSON", "Entity A"), ("COMPANY", "B")] + assert doc.labeled_spans_pair.resolve() == [("PERSON", "she"), ("COMPANY", "C")] + assert doc.binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))) + ] + + +@pytest.fixture(scope="module") +def positive_documents(): + doc1 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + id="0", text="Entity A works at B.", text_pair="And she founded C." + ) + doc1.labeled_spans.append(LabeledSpan(start=0, end=8, label="PERSON")) + doc1.labeled_spans.append(LabeledSpan(start=18, end=19, label="COMPANY")) + doc1.labeled_spans_pair.append(LabeledSpan(start=4, end=7, label="PERSON")) + doc1.labeled_spans_pair.append(LabeledSpan(start=16, end=17, label="COMPANY")) + doc1.binary_coref_relations.append( + BinaryCorefRelation(head=doc1.labeled_spans[0], tail=doc1.labeled_spans_pair[0]) + ) + + doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + id="0", text="Bob loves his cat.", text_pair="She sleeps a lot." + ) + doc2.labeled_spans.append(LabeledSpan(start=0, end=3, label="PERSON")) + doc2.labeled_spans.append(LabeledSpan(start=10, end=17, label="ANIMAL")) + doc2.labeled_spans_pair.append(LabeledSpan(start=0, end=3, label="ANIMAL")) + doc2.binary_coref_relations.append( + BinaryCorefRelation(head=doc2.labeled_spans[1], tail=doc2.labeled_spans_pair[0]) + ) + + return [doc1, doc2] + + +def test_positive_documents(positive_documents): + assert len(positive_documents) == 2 + doc1, doc2 = positive_documents + assert doc1.labeled_spans.resolve() == [("PERSON", "Entity A"), ("COMPANY", "B")] + assert doc1.labeled_spans_pair.resolve() == [("PERSON", "she"), ("COMPANY", "C")] + assert doc1.binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))) + ] + + assert doc2.labeled_spans.resolve() == [("PERSON", "Bob"), ("ANIMAL", "his cat")] + assert doc2.labeled_spans_pair.resolve() == [("ANIMAL", "She")] + assert doc2.binary_coref_relations.resolve() == [ + ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) + ] + + +@pytest.fixture(scope="module") +def positive_and_negative_documents(positive_documents): + docs = list(add_negative_coref_relations(positive_documents)) + return docs + + +def test_construct_negative_documents(positive_and_negative_documents): + assert len(positive_and_negative_documents) == 8 + TEXTS = [ + "Entity A works at B.", + "And she founded C.", + "Bob loves his cat.", + "She sleeps a lot.", + ] + assert all(doc.text in TEXTS for doc in positive_and_negative_documents) + assert all(doc.text_pair in TEXTS for doc in positive_and_negative_documents) + + all_texts = [(doc.text, doc.text_pair) for doc in positive_and_negative_documents] + all_scores = [ + [coref_rel.score for coref_rel in doc.binary_coref_relations] + for doc in positive_and_negative_documents + ] + all_rels_resolved = [ + doc.binary_coref_relations.resolve() for doc in positive_and_negative_documents + ] + + # check number of all relations + all_rels_flat = [ + rel for doc in positive_and_negative_documents for rel in doc.binary_coref_relations + ] + assert len(all_rels_flat) == 10 + # positives + assert len([rel.score for rel in all_rels_flat if rel.score > 0.0]) == 4 + # negatives + assert len([rel.score for rel in all_rels_flat if rel.score == 0.0]) == 6 + + all_rels_and_scores = [ + (texts, list(zip(scores, rels_resolved))) + for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved) + ] + + assert all_rels_and_scores == [ + ( + ("And she founded C.", "Bob loves his cat."), + [(0.0, ("coref", (("PERSON", "she"), ("PERSON", "Bob"))))], + ), + ( + ("And she founded C.", "Entity A works at B."), + [ + (1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))), + ], + ), + ( + ("Bob loves his cat.", "And she founded C."), + [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she"))))], + ), + ( + ("Bob loves his cat.", "Entity A works at B."), + [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A"))))], + ), + ( + ("Bob loves his cat.", "She sleeps a lot."), + [(1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))))], + ), + ( + ("Entity A works at B.", "And she founded C."), + [ + (1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))), + (0.0, ("coref", (("COMPANY", "B"), ("COMPANY", "C")))), + ], + ), + ( + ("Entity A works at B.", "Bob loves his cat."), + [(0.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "Bob"))))], + ), + ( + ("She sleeps a lot.", "Bob loves his cat."), + [(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))], + ), + ] + + +def _get_all_all_rels_and_scores(docs): + all_texts = [(doc.text, doc.text_pair) for doc in docs] + all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs] + all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs] + + all_rels_and_scores = [ + (texts, list(zip(scores, rels_resolved))) + for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved) + ] + return all_rels_and_scores + + +def test_construct_negative_documents_with_downsampling(positive_documents, caplog): + # set fixed seed because the negatives will get shuffled + random.seed(42) + docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=1.0)) + all_rels_and_scores = _get_all_all_rels_and_scores(docs) + + # check number relations + all_rels_flat = [rel for doc in docs for rel in doc.binary_coref_relations] + # positives + assert len([rel.score for rel in all_rels_flat if rel.score > 0.0]) == 4 + # negatives (same number positives because downsampling_factor=1.0) + assert len([rel.score for rel in all_rels_flat if rel.score == 0.0]) == 4 + + assert all_rels_and_scores == [ + ( + ("And she founded C.", "Entity A works at B."), + [ + (1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A")))), + (0.0, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))), + ], + ), + ( + ("Bob loves his cat.", "And she founded C."), + [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she"))))], + ), + ( + ("Bob loves his cat.", "Entity A works at B."), + [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A"))))], + ), + ( + ("Bob loves his cat.", "She sleeps a lot."), + [(1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))))], + ), + ( + ("Entity A works at B.", "And she founded C."), + [ + (1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))), + (0.0, ("coref", (("COMPANY", "B"), ("COMPANY", "C")))), + ], + ), + ( + ("She sleeps a lot.", "Bob loves his cat."), + [(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))], + ), + ] + + # sampling target is too low + caplog.clear() + docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=0.0)) + assert caplog.messages == [ + "downsampling with factor=0.0 and number of positive relations=4 does not produce any negatives" + ] + # check number relations + all_rels_flat = [rel for doc in docs for rel in doc.binary_coref_relations] + # positives: 2 x number of positives (we add instances with swapped texts) + assert len([rel.score for rel in all_rels_flat if rel.score > 0.0]) == 4 + # negatives + assert len([rel.score for rel in all_rels_flat if rel.score == 0.0]) == 0 + # check actual content + all_rels_and_scores = _get_all_all_rels_and_scores(docs) + assert all_rels_and_scores == [ + ( + ("And she founded C.", "Entity A works at B."), + [(1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A"))))], + ), + ( + ("Bob loves his cat.", "She sleeps a lot."), + [(1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))))], + ), + ( + ("Entity A works at B.", "And she founded C."), + [(1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))))], + ), + ( + ("She sleeps a lot.", "Bob loves his cat."), + [(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))], + ), + ] + + # no positives + doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + id="0", text="Bob loves his cat.", text_pair="She sleeps a lot." + ) + doc2.labeled_spans.append(LabeledSpan(start=0, end=3, label="PERSON")) + doc2.labeled_spans.append(LabeledSpan(start=10, end=17, label="ANIMAL")) + doc2.labeled_spans_pair.append(LabeledSpan(start=0, end=3, label="ANIMAL")) + with pytest.raises(ValueError) as e: + list(add_negative_coref_relations([doc2], downsampling_factor=1.0)) + assert ( + str(e.value) + == "downsampling [factor=1.0] is enabled, but no positive relations are available to calculate " + "max_num_negative" + ) + + +def test_construct_text_document_from_text_pair_coref_document(positive_and_negative_documents): + glue_text = "" + docs = [ + construct_text_document_from_text_pair_coref_document( + doc, + glue_text=glue_text, + no_relation_label="no_relation", + relation_label_mapping={"coref": "semantically_same"}, + ) + for doc in positive_and_negative_documents + ] + assert len(docs) == 8 + doc = docs[0] + assert doc.text == "And she founded C.Bob loves his cat." + assert doc.labeled_spans.resolve() == [ + ("PERSON", "she"), + ("COMPANY", "C"), + ("PERSON", "Bob"), + ("ANIMAL", "his cat"), + ] + assert doc.binary_relations.resolve() == [ + ("no_relation", (("PERSON", "she"), ("PERSON", "Bob"))) + ] + assert [rel.score for rel in doc.binary_relations] == [1.0] + + doc = docs[4] + assert doc.text == "Bob loves his cat.She sleeps a lot." + assert doc.labeled_spans.resolve() == [ + ("PERSON", "Bob"), + ("ANIMAL", "his cat"), + ("ANIMAL", "She"), + ] + assert doc.binary_relations.resolve() == [ + ("semantically_same", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) + ] + assert [rel.score for rel in doc.binary_relations] == [1.0] diff --git a/tests/fixtures/taskmodules/cross_text_binary_coref/documents_with_negatives.json b/tests/fixtures/taskmodules/cross_text_binary_coref/documents_with_negatives.json new file mode 100644 index 000000000..cbc258ecd --- /dev/null +++ b/tests/fixtures/taskmodules/cross_text_binary_coref/documents_with_negatives.json @@ -0,0 +1,792 @@ +[ + { + "text_pair": "And she founded C.", + "text": "And she founded C.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [], + "predictions": [] + } + }, + { + "text_pair": "Bob loves his cat.", + "text": "And she founded C.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": 2545181322977893893, + "tail": -7091027580690283656, + "label": "coref", + "score": 0.0, + "_id": -1763877672186772918 + } + ], + "predictions": [] + } + }, + { + "text_pair": "Entity A works at B.", + "text": "And she founded C.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": 2545181322977893893, + "tail": -177396764231138184, + "label": "coref", + "score": 1.0, + "_id": 5113198133391321397 + }, + { + "head": -2143209897469179365, + "tail": 3188240167591245379, + "label": "coref", + "score": 0.0, + "_id": -734219494647036300 + } + ], + "predictions": [] + } + }, + { + "text_pair": "She sleeps a lot.", + "text": "And she founded C.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": 2360667792531975882 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [], + "predictions": [] + } + }, + { + "text_pair": "And she founded C.", + "text": "Bob loves his cat.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -7091027580690283656, + "tail": 2545181322977893893, + "label": "coref", + "score": 0.0, + "_id": 4323963091729289163 + } + ], + "predictions": [] + } + }, + { + "text_pair": "Bob loves his cat.", + "text": "Bob loves his cat.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [], + "predictions": [] + } + }, + { + "text_pair": "Entity A works at B.", + "text": "Bob loves his cat.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -7091027580690283656, + "tail": -177396764231138184, + "label": "coref", + "score": 0.0, + "_id": -4269111567075058761 + } + ], + "predictions": [] + } + }, + { + "text_pair": "She sleeps a lot.", + "text": "Bob loves his cat.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": 2360667792531975882 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -6613361595321704194, + "tail": 2360667792531975882, + "label": "coref", + "score": 1.0, + "_id": 8198921634551745514 + } + ], + "predictions": [] + } + }, + { + "text_pair": "And she founded C.", + "text": "Entity A works at B.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -177396764231138184, + "tail": 2545181322977893893, + "label": "coref", + "score": 1.0, + "_id": -4710872194864906092 + }, + { + "head": 3188240167591245379, + "tail": -2143209897469179365, + "label": "coref", + "score": 0.0, + "_id": 2636939255468582059 + } + ], + "predictions": [] + } + }, + { + "text_pair": "Bob loves his cat.", + "text": "Entity A works at B.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": -177396764231138184, + "tail": -7091027580690283656, + "label": "coref", + "score": 0.0, + "_id": -1990964066152094896 + } + ], + "predictions": [] + } + }, + { + "text_pair": "Entity A works at B.", + "text": "Entity A works at B.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [], + "predictions": [] + } + }, + { + "text_pair": "She sleeps a lot.", + "text": "Entity A works at B.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": 2360667792531975882 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [], + "predictions": [] + } + }, + { + "text_pair": "And she founded C.", + "text": "She sleeps a lot.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": 2360667792531975882 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 16, + "end": 17, + "label": "COMPANY", + "score": 1.0, + "_id": -2143209897469179365 + }, + { + "start": 4, + "end": 7, + "label": "PERSON", + "score": 1.0, + "_id": 2545181322977893893 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [], + "predictions": [] + } + }, + { + "text_pair": "Bob loves his cat.", + "text": "She sleeps a lot.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": 2360667792531975882 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "PERSON", + "score": 1.0, + "_id": -7091027580690283656 + }, + { + "start": 10, + "end": 17, + "label": "ANIMAL", + "score": 1.0, + "_id": -6613361595321704194 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [ + { + "head": 2360667792531975882, + "tail": -6613361595321704194, + "label": "coref", + "score": 1.0, + "_id": -571410837328299027 + } + ], + "predictions": [] + } + }, + { + "text_pair": "Entity A works at B.", + "text": "She sleeps a lot.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": 2360667792531975882 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 18, + "end": 19, + "label": "COMPANY", + "score": 1.0, + "_id": 3188240167591245379 + }, + { + "start": 0, + "end": 8, + "label": "PERSON", + "score": 1.0, + "_id": -177396764231138184 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [], + "predictions": [] + } + }, + { + "text_pair": "She sleeps a lot.", + "text": "She sleeps a lot.", + "id": null, + "metadata": null, + "labeled_spans": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": 2360667792531975882 + } + ], + "predictions": [] + }, + "labeled_spans_pair": { + "annotations": [ + { + "start": 0, + "end": 3, + "label": "ANIMAL", + "score": 1.0, + "_id": 2360667792531975882 + } + ], + "predictions": [] + }, + "binary_coref_relations": { + "annotations": [], + "predictions": [] + } + } +] diff --git a/tests/models/test_sequence_pair_similarity_model_with_pooler.py b/tests/models/test_sequence_pair_similarity_model_with_pooler.py new file mode 100644 index 000000000..3e6871a11 --- /dev/null +++ b/tests/models/test_sequence_pair_similarity_model_with_pooler.py @@ -0,0 +1,326 @@ +from typing import Dict + +import pytest +import torch +from pytorch_lightning import Trainer +from torch import LongTensor, tensor +from torch.optim.lr_scheduler import LambdaLR +from transformers.modeling_outputs import SequenceClassifierOutput + +from pie_modules.models import SequencePairSimilarityModelWithPooler +from pie_modules.models.sequence_classification_with_pooler import OutputType +from tests.models import trunc_number + + +@pytest.fixture +def inputs() -> Dict[str, LongTensor]: + result_dict = { + "encoding": { + "input_ids": tensor( + [ + [101, 1262, 1131, 1771, 140, 119, 102], + [101, 1262, 1131, 1771, 140, 119, 102], + [101, 1262, 1131, 1771, 140, 119, 102], + [101, 1262, 1131, 1771, 140, 119, 102], + ] + ), + "token_type_ids": tensor( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + "attention_mask": tensor( + [ + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + ] + ), + }, + "encoding_pair": { + "input_ids": tensor( + [ + [101, 3162, 7871, 1117, 5855, 119, 102], + [101, 3162, 7871, 1117, 5855, 119, 102], + [101, 3162, 7871, 1117, 5855, 119, 102], + [101, 3162, 7871, 1117, 5855, 119, 102], + ] + ), + "token_type_ids": tensor( + [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ] + ), + "attention_mask": tensor( + [ + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + ] + ), + }, + "pooler_start_indices": tensor([[2], [2], [4], [4]]), + "pooler_end_indices": tensor([[3], [3], [5], [5]]), + "pooler_pair_start_indices": tensor([[1], [3], [1], [3]]), + "pooler_pair_end_indices": tensor([[2], [5], [2], [5]]), + } + + return result_dict + + +@pytest.fixture +def targets() -> Dict[str, LongTensor]: + return {"scores": tensor([0.0, 0.0, 0.0, 0.0])} + + +@pytest.fixture +def model() -> SequencePairSimilarityModelWithPooler: + torch.manual_seed(42) + result = SequencePairSimilarityModelWithPooler( + model_name_or_path="prajjwal1/bert-tiny", + ) + return result + + +def test_model(model): + assert model is not None + named_parameters = dict(model.named_parameters()) + parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()} + parameter_means_expected = { + "model.embeddings.word_embeddings.weight": 0.0031152, + "model.embeddings.position_embeddings.weight": 5.5e-05, + "model.embeddings.token_type_embeddings.weight": -0.0015419, + "model.embeddings.LayerNorm.weight": 1.312345, + "model.embeddings.LayerNorm.bias": -0.0294608, + "model.encoder.layer.0.attention.self.query.weight": -0.0003949, + "model.encoder.layer.0.attention.self.query.bias": 0.0185744, + "model.encoder.layer.0.attention.self.key.weight": 0.0003863, + "model.encoder.layer.0.attention.self.key.bias": 0.0020557, + "model.encoder.layer.0.attention.self.value.weight": 4.22e-05, + "model.encoder.layer.0.attention.self.value.bias": 0.0065417, + "model.encoder.layer.0.attention.output.dense.weight": 3.01e-05, + "model.encoder.layer.0.attention.output.dense.bias": 0.0007209, + "model.encoder.layer.0.attention.output.LayerNorm.weight": 1.199831, + "model.encoder.layer.0.attention.output.LayerNorm.bias": 0.0608714, + "model.encoder.layer.0.intermediate.dense.weight": -0.0011731, + "model.encoder.layer.0.intermediate.dense.bias": -0.1219958, + "model.encoder.layer.0.output.dense.weight": -0.0002212, + "model.encoder.layer.0.output.dense.bias": -0.0013031, + "model.encoder.layer.0.output.LayerNorm.weight": 1.2419648, + "model.encoder.layer.0.output.LayerNorm.bias": 0.005295, + "model.encoder.layer.1.attention.self.query.weight": -0.0007321, + "model.encoder.layer.1.attention.self.query.bias": -0.0358397, + "model.encoder.layer.1.attention.self.key.weight": 0.0001333, + "model.encoder.layer.1.attention.self.key.bias": 0.0045062, + "model.encoder.layer.1.attention.self.value.weight": 0.0001012, + "model.encoder.layer.1.attention.self.value.bias": -0.0007094, + "model.encoder.layer.1.attention.output.dense.weight": -2.43e-05, + "model.encoder.layer.1.attention.output.dense.bias": 0.0041446, + "model.encoder.layer.1.attention.output.LayerNorm.weight": 1.0377343, + "model.encoder.layer.1.attention.output.LayerNorm.bias": 0.0443237, + "model.encoder.layer.1.intermediate.dense.weight": -0.001344, + "model.encoder.layer.1.intermediate.dense.bias": -0.1247257, + "model.encoder.layer.1.output.dense.weight": -5.32e-05, + "model.encoder.layer.1.output.dense.bias": 0.000677, + "model.encoder.layer.1.output.LayerNorm.weight": 1.017162, + "model.encoder.layer.1.output.LayerNorm.bias": -0.0474442, + "model.pooler.dense.weight": 0.0001295, + "model.pooler.dense.bias": -0.0052078, + "pooler.missing_embeddings": 0.0812017, + } + assert parameter_means == parameter_means_expected + + +def test_model_pickleable(model): + import pickle + + pickle.dumps(model) + + +@pytest.fixture +def model_output(model, inputs) -> OutputType: + # set seed to make sure the output is deterministic + torch.manual_seed(42) + return model(inputs) + + +def test_forward_logits(model_output, inputs): + assert isinstance(model_output, SequenceClassifierOutput) + + logits = model_output.logits + + torch.testing.assert_close( + logits, + torch.tensor( + [0.5338148474693298, 0.5866107940673828, 0.5076886415481567, 0.5946245789527893] + ), + ) + + +def test_decode(model, model_output, inputs): + decoded = model.decode(inputs=inputs, outputs=model_output) + assert isinstance(decoded, dict) + assert set(decoded) == {"scores"} + scores = decoded["scores"] + torch.testing.assert_close( + scores, + torch.tensor( + [0.5338148474693298, 0.5866107940673828, 0.5076886415481567, 0.5946245789527893] + ), + ) + + +@pytest.fixture +def batch(inputs, targets): + return inputs, targets + + +def test_training_step(batch, model): + # set the seed to make sure the loss is deterministic + torch.manual_seed(42) + loss = model.training_step(batch, batch_idx=0) + assert loss is not None + torch.testing.assert_close(loss, torch.tensor(0.8145309686660767)) + + +def test_validation_step(batch, model): + # set the seed to make sure the loss is deterministic + torch.manual_seed(42) + loss = model.validation_step(batch, batch_idx=0) + assert loss is not None + torch.testing.assert_close(loss, torch.tensor(0.8145309686660767)) + + +def test_test_step(batch, model): + # set the seed to make sure the loss is deterministic + torch.manual_seed(42) + loss = model.test_step(batch, batch_idx=0) + assert loss is not None + torch.testing.assert_close(loss, torch.tensor(0.8145309686660767)) + + +def test_base_model_named_parameters(model): + base_model_named_parameters = dict(model.base_model_named_parameters()) + assert set(base_model_named_parameters) == { + "model.pooler.dense.bias", + "model.encoder.layer.0.intermediate.dense.weight", + "model.encoder.layer.0.intermediate.dense.bias", + "model.encoder.layer.1.attention.output.dense.weight", + "model.encoder.layer.1.attention.output.LayerNorm.weight", + "model.encoder.layer.1.attention.self.query.weight", + "model.encoder.layer.1.output.dense.weight", + "model.encoder.layer.0.output.dense.bias", + "model.encoder.layer.1.intermediate.dense.bias", + "model.encoder.layer.1.attention.self.value.bias", + "model.encoder.layer.0.attention.output.dense.weight", + "model.encoder.layer.0.attention.self.query.bias", + "model.encoder.layer.0.attention.self.value.bias", + "model.encoder.layer.1.output.dense.bias", + "model.encoder.layer.1.attention.self.query.bias", + "model.encoder.layer.1.attention.output.LayerNorm.bias", + "model.encoder.layer.0.attention.self.query.weight", + "model.encoder.layer.0.attention.output.LayerNorm.bias", + "model.encoder.layer.0.attention.self.key.bias", + "model.encoder.layer.1.intermediate.dense.weight", + "model.encoder.layer.1.output.LayerNorm.bias", + "model.encoder.layer.1.output.LayerNorm.weight", + "model.encoder.layer.0.attention.self.key.weight", + "model.encoder.layer.1.attention.output.dense.bias", + "model.encoder.layer.0.attention.output.dense.bias", + "model.embeddings.LayerNorm.bias", + "model.encoder.layer.0.attention.self.value.weight", + "model.encoder.layer.0.attention.output.LayerNorm.weight", + "model.embeddings.token_type_embeddings.weight", + "model.encoder.layer.0.output.LayerNorm.weight", + "model.embeddings.position_embeddings.weight", + "model.encoder.layer.1.attention.self.key.bias", + "model.embeddings.LayerNorm.weight", + "model.encoder.layer.0.output.LayerNorm.bias", + "model.encoder.layer.1.attention.self.key.weight", + "model.pooler.dense.weight", + "model.encoder.layer.0.output.dense.weight", + "model.embeddings.word_embeddings.weight", + "model.encoder.layer.1.attention.self.value.weight", + } + + +def test_task_named_parameters(model): + task_named_parameters = dict(model.task_named_parameters()) + assert set(task_named_parameters) == { + "pooler.missing_embeddings", + } + + +def test_configure_optimizers_with_warmup(): + model = SequencePairSimilarityModelWithPooler( + model_name_or_path="prajjwal1/bert-tiny", + ) + model.trainer = Trainer(max_epochs=10) + optimizers_and_schedulers = model.configure_optimizers() + assert len(optimizers_and_schedulers) == 2 + optimizers, schedulers = optimizers_and_schedulers + assert len(optimizers) == 1 + assert len(schedulers) == 1 + optimizer = optimizers[0] + assert optimizer is not None + assert isinstance(optimizer, torch.optim.AdamW) + assert optimizer.defaults["lr"] == 1e-05 + assert optimizer.defaults["weight_decay"] == 0.01 + assert optimizer.defaults["eps"] == 1e-08 + + scheduler = schedulers[0] + assert isinstance(scheduler, dict) + assert set(scheduler) == {"scheduler", "interval"} + assert isinstance(scheduler["scheduler"], LambdaLR) + + +def test_configure_optimizers_with_task_learning_rate(monkeypatch): + model = SequencePairSimilarityModelWithPooler( + model_name_or_path="prajjwal1/bert-tiny", + learning_rate=1e-5, + task_learning_rate=1e-3, + # disable warmup to make sure the scheduler is not added which would set the learning rate + # to 0 + warmup_proportion=0.0, + ) + optimizer = model.configure_optimizers() + assert optimizer is not None + assert isinstance(optimizer, torch.optim.AdamW) + assert len(optimizer.param_groups) == 2 + # base model parameters + param_group = optimizer.param_groups[0] + assert len(param_group["params"]) == 39 + assert param_group["lr"] == 1e-5 + # classifier head parameters - there is just the default embedding (which is not used) + param_group = optimizer.param_groups[1] + assert len(param_group["params"]) == 1 + assert param_group["lr"] == 1e-3 + # ensure that all parameters are covered + assert set(optimizer.param_groups[0]["params"] + optimizer.param_groups[1]["params"]) == set( + model.parameters() + ) + + +def test_freeze_base_model(monkeypatch, inputs, targets): + model = SequencePairSimilarityModelWithPooler( + model_name_or_path="prajjwal1/bert-tiny", + freeze_base_model=True, + # disable warmup to make sure the scheduler is not added which would set the learning rate + # to 0 + warmup_proportion=0.0, + ) + base_model_params = [param for name, param in model.base_model_named_parameters()] + task_params = [param for name, param in model.task_named_parameters()] + assert len(base_model_params) + len(task_params) == len(list(model.parameters())) + for param in base_model_params: + assert not param.requires_grad + for param in task_params: + assert param.requires_grad diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py new file mode 100644 index 000000000..5b16d2af1 --- /dev/null +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -0,0 +1,401 @@ +import json +import logging +from typing import Any, Dict, Union + +import pytest +import torch.testing +from pytorch_ie.annotations import LabeledSpan +from torch import tensor +from torchmetrics import Metric, MetricCollection + +from pie_modules.document.processing.text_pair import add_negative_coref_relations +from pie_modules.documents import ( + BinaryCorefRelation, + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, +) +from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule +from pie_modules.utils import flatten_dict, list_of_dicts2dict_of_lists +from tests import FIXTURES_ROOT, _config_to_str + +TOKENIZER_NAME_OR_PATH = "bert-base-cased" +DOC_IDX_WITH_TASK_ENCODINGS = 2 + +CONFIGS = [ + {}, +] +CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} + + +@pytest.fixture(scope="module", params=CONFIGS_DICT.keys()) +def config(request): + return CONFIGS_DICT[request.param] + + +@pytest.fixture(scope="module") +def positive_documents(): + doc1 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + id="0", text="Entity A works at B.", text_pair="And she founded C." + ) + doc1.labeled_spans.append(LabeledSpan(start=0, end=8, label="PERSON")) + doc1.labeled_spans.append(LabeledSpan(start=18, end=19, label="COMPANY")) + doc1.labeled_spans_pair.append(LabeledSpan(start=4, end=7, label="PERSON")) + doc1.labeled_spans_pair.append(LabeledSpan(start=16, end=17, label="COMPANY")) + doc1.binary_coref_relations.append( + BinaryCorefRelation(head=doc1.labeled_spans[0], tail=doc1.labeled_spans_pair[0]) + ) + + doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( + id="0", text="Bob loves his cat.", text_pair="She sleeps a lot." + ) + doc2.labeled_spans.append(LabeledSpan(start=0, end=3, label="PERSON")) + doc2.labeled_spans.append(LabeledSpan(start=10, end=17, label="ANIMAL")) + doc2.labeled_spans_pair.append(LabeledSpan(start=0, end=3, label="ANIMAL")) + doc2.binary_coref_relations.append( + BinaryCorefRelation(head=doc2.labeled_spans[1], tail=doc2.labeled_spans_pair[0]) + ) + + return [doc1, doc2] + + +def test_positive_documents(positive_documents): + assert len(positive_documents) == 2 + doc1, doc2 = positive_documents + assert doc1.labeled_spans.resolve() == [("PERSON", "Entity A"), ("COMPANY", "B")] + assert doc1.labeled_spans_pair.resolve() == [("PERSON", "she"), ("COMPANY", "C")] + assert doc1.binary_coref_relations.resolve() == [ + ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))) + ] + + assert doc2.labeled_spans.resolve() == [("PERSON", "Bob"), ("ANIMAL", "his cat")] + assert doc2.labeled_spans_pair.resolve() == [("ANIMAL", "She")] + assert doc2.binary_coref_relations.resolve() == [ + ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) + ] + + +@pytest.fixture(scope="module") +def unprepared_taskmodule(config): + taskmodule = CrossTextBinaryCorefTaskModule( + tokenizer_name_or_path=TOKENIZER_NAME_OR_PATH, **config + ) + assert not taskmodule.is_from_pretrained + + return taskmodule + + +@pytest.fixture(scope="module") +def taskmodule(unprepared_taskmodule, positive_documents): + unprepared_taskmodule.prepare(positive_documents) + return unprepared_taskmodule + + +@pytest.fixture(scope="module") +def documents_with_negatives(taskmodule, positive_documents): + file_name = ( + FIXTURES_ROOT / "taskmodules" / "cross_text_binary_coref" / "documents_with_negatives.json" + ) + + # result = list(add_negative_relations(positive_documents)) + # result_json = [doc.asdict() for doc in result] + # with open(file_name, "w") as f: + # json.dump(result_json, f, indent=2) + + with open(file_name) as f: + result_json = json.load(f) + result = [ + TextPairDocumentWithLabeledSpansAndBinaryCorefRelations.fromdict(doc_json) + for doc_json in result_json + ] + + return result + + +@pytest.fixture(scope="module") +def task_encodings_without_target(taskmodule, documents_with_negatives): + task_encodings = taskmodule.encode_input(documents_with_negatives[DOC_IDX_WITH_TASK_ENCODINGS]) + return task_encodings + + +def test_encode_input(task_encodings_without_target, taskmodule): + task_encodings = task_encodings_without_target + convert_ids_to_tokens = taskmodule.tokenizer.convert_ids_to_tokens + + inputs_dict = list_of_dicts2dict_of_lists( + [task_encoding.inputs for task_encoding in task_encodings] + ) + tokens = [convert_ids_to_tokens(encoding["input_ids"]) for encoding in inputs_dict["encoding"]] + tokens_pair = [ + convert_ids_to_tokens(encoding["input_ids"]) for encoding in inputs_dict["encoding_pair"] + ] + assert tokens == [ + ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"], + ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"], + ] + assert tokens_pair == [ + ["[CLS]", "En", "##ti", "##ty", "A", "works", "at", "B", ".", "[SEP]"], + ["[CLS]", "En", "##ti", "##ty", "A", "works", "at", "B", ".", "[SEP]"], + ] + span_tokens = [ + toks[start:end] + for toks, start, end in zip( + tokens, inputs_dict["pooler_start_indices"], inputs_dict["pooler_end_indices"] + ) + ] + span_tokens_pair = [ + toks[start:end] + for toks, start, end in zip( + tokens_pair, + inputs_dict["pooler_pair_start_indices"], + inputs_dict["pooler_pair_end_indices"], + ) + ] + assert span_tokens == [["she"], ["C"]] + assert span_tokens_pair == [["En", "##ti", "##ty", "A"], ["B"]] + + +def test_encode_target(task_encodings_without_target, taskmodule): + targets = [ + taskmodule.encode_target(task_encoding) for task_encoding in task_encodings_without_target + ] + assert targets == [1.0, 0.0] + + +def test_encode_with_collect_statistics(taskmodule, positive_documents, caplog): + documents_with_negatives = add_negative_coref_relations(positive_documents) + caplog.clear() + with caplog.at_level(logging.INFO): + original_values = taskmodule.collect_statistics + taskmodule.collect_statistics = True + taskmodule.encode(documents_with_negatives, encode_target=True) + taskmodule.collect_statistics = original_values + + assert len(caplog.messages) == 1 + assert ( + caplog.messages[0] == "statistics:\n" + "| | coref | no_relation | all_relations |\n" + "|:----------|--------:|--------------:|----------------:|\n" + "| available | 4 | 6 | 4 |\n" + "| used | 4 | 6 | 4 |\n" + "| used % | 100 | 100 | 100 |" + ) + + +def test_encode_with_windowing(documents_with_negatives, caplog): + tokenizer_name_or_path = "bert-base-cased" + taskmodule = CrossTextBinaryCorefTaskModule( + tokenizer_name_or_path=tokenizer_name_or_path, + max_window=4, + collect_statistics=True, + ) + assert not taskmodule.is_from_pretrained + taskmodule.prepare(documents_with_negatives) + + assert len(documents_with_negatives) == 16 + caplog.clear() + with caplog.at_level(logging.INFO): + task_encodings = taskmodule.encode(documents_with_negatives) + assert len(caplog.messages) > 0 + assert ( + caplog.messages[-1] == "statistics:\n" + "| | coref | no_relation | all_relations |\n" + "|:--------------------------------------|--------:|--------------:|----------------:|\n" + "| available | 4 | 6 | 4 |\n" + "| skipped_span_does_not_fit_into_window | 2 | 2 | 2 |\n" + "| used | 2 | 4 | 2 |\n" + "| used % | 50 | 67 | 50 |" + ) + + assert len(task_encodings) == 6 + for task_encoding in task_encodings: + for k, v in task_encoding.inputs["encoding"].items(): + assert len(v) <= taskmodule.max_window + for k, v in task_encoding.inputs["encoding_pair"].items(): + assert len(v) <= taskmodule.max_window + + +@pytest.fixture(scope="module") +def task_encodings(taskmodule, documents_with_negatives): + return taskmodule.encode( + documents_with_negatives[DOC_IDX_WITH_TASK_ENCODINGS], encode_target=True + ) + + +@pytest.fixture(scope="module") +def batch(taskmodule, task_encodings): + result = taskmodule.collate(task_encodings) + return result + + +def test_collate(batch, taskmodule): + assert batch is not None + inputs, targets = batch + assert inputs is not None + assert set(inputs) == { + "pooler_end_indices", + "encoding_pair", + "pooler_pair_end_indices", + "pooler_start_indices", + "encoding", + "pooler_pair_start_indices", + } + torch.testing.assert_close( + inputs["encoding"]["input_ids"], + torch.tensor( + [[101, 1262, 1131, 1771, 140, 119, 102], [101, 1262, 1131, 1771, 140, 119, 102]] + ), + ) + torch.testing.assert_close( + inputs["encoding"]["token_type_ids"], torch.zeros_like(inputs["encoding"]["input_ids"]) + ) + torch.testing.assert_close( + inputs["encoding"]["attention_mask"], torch.ones_like(inputs["encoding"]["input_ids"]) + ) + + torch.testing.assert_close( + inputs["encoding_pair"]["input_ids"], + torch.tensor( + [ + [101, 13832, 3121, 2340, 138, 1759, 1120, 139, 119, 102], + [101, 13832, 3121, 2340, 138, 1759, 1120, 139, 119, 102], + ] + ), + ) + torch.testing.assert_close( + inputs["encoding_pair"]["token_type_ids"], + torch.zeros_like(inputs["encoding_pair"]["input_ids"]), + ) + torch.testing.assert_close( + inputs["encoding_pair"]["attention_mask"], + torch.ones_like(inputs["encoding_pair"]["input_ids"]), + ) + + torch.testing.assert_close(inputs["pooler_start_indices"], torch.tensor([[2], [4]])) + torch.testing.assert_close(inputs["pooler_end_indices"], torch.tensor([[3], [5]])) + torch.testing.assert_close(inputs["pooler_pair_start_indices"], torch.tensor([[1], [7]])) + torch.testing.assert_close(inputs["pooler_pair_end_indices"], torch.tensor([[5], [8]])) + + torch.testing.assert_close(targets, {"scores": torch.tensor([1.0, 0.0])}) + + +@pytest.fixture(scope="module") +def unbatched_output(taskmodule): + model_output = { + "scores": torch.tensor([0.5338148474693298, 0.9866107940673828]), + } + return taskmodule.unbatch_output(model_output=model_output) + + +def test_unbatch_output(unbatched_output, taskmodule): + assert len(unbatched_output) == 2 + assert unbatched_output == [ + {"is_similar": False, "score": 0.5338148474693298}, + {"is_similar": True, "score": 0.9866107702255249}, + ] + + +def test_create_annotation_from_output(taskmodule, task_encodings, unbatched_output): + all_new_annotations = [] + for task_encoding, task_output in zip(task_encodings, unbatched_output): + for new_annotation in taskmodule.create_annotations_from_output( + task_encoding=task_encoding, task_output=task_output + ): + all_new_annotations.append(new_annotation) + assert all(layer_name == "binary_coref_relations" for layer_name, ann in all_new_annotations) + resolve_annotations_with_scores = [ + (round(ann.score, 4), ann.resolve()) for layer_name, ann in all_new_annotations + ] + assert resolve_annotations_with_scores == [ + (0.9866, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))), + ] + + +def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]: + if isinstance(metric_or_collection, Metric): + return flatten_dict(metric_or_collection.metric_state) + elif isinstance(metric_or_collection, MetricCollection): + return flatten_dict({k: get_metric_state(v) for k, v in metric_or_collection.items()}) + else: + raise ValueError(f"unsupported type: {type(metric_or_collection)}") + + +def test_configure_metric(taskmodule, batch): + metric = taskmodule.configure_model_metric(stage="train") + + assert isinstance(metric, (Metric, MetricCollection)) + state = get_metric_state(metric) + torch.testing.assert_close( + state, + { + "continuous/auroc/preds": [], + "continuous/auroc/target": [], + "continuous/avg-P/preds": [], + "continuous/avg-P/target": [], + "continuous/f1/fn": tensor([0]), + "continuous/f1/fp": tensor([0]), + "continuous/f1/tn": tensor([0]), + "continuous/f1/tp": tensor([0]), + }, + ) + + # targets = batch[1] + targets = { + "scores": torch.tensor([0.0, 1.0, 0.0, 0.0]), + } + metric.update(targets, targets) + + state = get_metric_state(metric) + torch.testing.assert_close( + state, + { + "continuous/auroc/preds": [tensor([0.0, 1.0, 0.0, 0.0])], + "continuous/auroc/target": [tensor([0.0, 1.0, 0.0, 0.0])], + "continuous/avg-P/preds": [tensor([0.0, 1.0, 0.0, 0.0])], + "continuous/avg-P/target": [tensor([0.0, 1.0, 0.0, 0.0])], + "continuous/f1/tp": tensor([1]), + "continuous/f1/fp": tensor([0]), + "continuous/f1/tn": tensor([3]), + "continuous/f1/fn": tensor([0]), + }, + ) + + torch.testing.assert_close( + metric.compute(), + {"auroc": tensor(1.0), "avg-P": tensor(1.0), "f1": tensor(1.0)}, + ) + + # torch.rand_like(targets) + random_targets = { + "scores": torch.tensor([0.2703, 0.6812, 0.2582, 0.9030]), + } + metric.update(random_targets, targets) + state = get_metric_state(metric) + torch.testing.assert_close( + state, + { + "continuous/auroc/preds": [ + tensor([0.0, 1.0, 0.0, 0.0]), + tensor([0.2703, 0.6812, 0.2582, 0.9030]), + ], + "continuous/auroc/target": [ + tensor([0.0, 1.0, 0.0, 0.0]), + tensor([0.0, 1.0, 0.0, 0.0]), + ], + "continuous/avg-P/preds": [ + tensor([0.0, 1.0, 0.0, 0.0]), + tensor([0.2703, 0.6812, 0.2582, 0.9030]), + ], + "continuous/avg-P/target": [ + tensor([0.0, 1.0, 0.0, 0.0]), + tensor([0.0, 1.0, 0.0, 0.0]), + ], + "continuous/f1/tp": tensor([1]), + "continuous/f1/fp": tensor([1]), + "continuous/f1/tn": tensor([5]), + "continuous/f1/fn": tensor([1]), + }, + ) + + torch.testing.assert_close( + metric.compute(), + {"auroc": tensor(0.91666663), "avg-P": tensor(0.83333337), "f1": tensor(0.50000000)}, + ) diff --git a/tests/utils/test_tokenization.py b/tests/utils/test_tokenization.py new file mode 100644 index 000000000..0b51ce62d --- /dev/null +++ b/tests/utils/test_tokenization.py @@ -0,0 +1,55 @@ +import pytest +from pytorch_ie.annotations import Span +from transformers import AutoTokenizer + +from pie_modules.utils.tokenization import ( + SpanNotAlignedWithTokenException, + get_aligned_token_span, +) + + +def test_get_aligned_token_span(): + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + + text = "Hello, world!" + encoding = tokenizer(text) + tokens = tokenizer.convert_ids_to_tokens(encoding.input_ids) + assert tokens == ["[CLS]", "Hello", ",", "world", "!", "[SEP]"] + + # already aligned + char_span = Span(0, 5) + assert text[char_span.start : char_span.end] == "Hello" + token_span = get_aligned_token_span(encoding=encoding, char_span=char_span) + assert tokens[token_span.start : token_span.end] == ["Hello"] + + # end not aligned + char_span = Span(5, 7) + assert text[char_span.start : char_span.end] == ", " + token_span = get_aligned_token_span(encoding=encoding, char_span=char_span) + assert tokens[token_span.start : token_span.end] == [","] + + # start not aligned + char_span = Span(6, 12) + assert text[char_span.start : char_span.end] == " world" + token_span = get_aligned_token_span(encoding=encoding, char_span=char_span) + assert tokens[token_span.start : token_span.end] == ["world"] + + # start not aligned, end inside token + char_span = Span(6, 8) + assert text[char_span.start : char_span.end] == " w" + token_span = get_aligned_token_span(encoding=encoding, char_span=char_span) + assert tokens[token_span.start : token_span.end] == ["world"] + + # empty char span + char_span = Span(2, 2) + assert text[char_span.start : char_span.end] == "" + with pytest.raises(SpanNotAlignedWithTokenException) as e: + get_aligned_token_span(encoding=encoding, char_span=char_span) + assert e.value.span == char_span + + # empty token span + char_span = Span(6, 7) + assert text[char_span.start : char_span.end] == " " + with pytest.raises(SpanNotAlignedWithTokenException) as e: + get_aligned_token_span(encoding=encoding, char_span=char_span) + assert e.value.span == char_span