Skip to content

Commit

Permalink
Merge pull request #110 from ArneBinder/taskmodules/cross_text_binary…
Browse files Browse the repository at this point in the history
…_coref

task: cross-text-coref
  • Loading branch information
ArneBinder authored Sep 16, 2024
2 parents 2ddb8a5 + c76caca commit cccdd33
Show file tree
Hide file tree
Showing 15 changed files with 2,731 additions and 2 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Available models:

- [SimpleSequenceClassificationModel](src/pie_modules/models/simple_sequence_classification.py)
- [SequenceClassificationModelWithPooler](src/pie_modules/models/sequence_classification_with_pooler.py)
- [SequencePairSimilarityModelWithPooler](src/pie_modules/models/sequence_classification_with_pooler.py)
- [SimpleTokenClassificationModel](src/pie_modules/models/simple_token_classification.py)
- [TokenClassificationModelWithSeq2SeqEncoderAndCrf](src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py)
- [SimpleExtractiveQuestionAnsweringModel](src/pie_modules/models/simple_extractive_question_answering.py)
Expand All @@ -25,6 +26,7 @@ Available models:
Available taskmodules:

- [RETextClassificationWithIndicesTaskModule](src/pie_modules/taskmodules/re_text_classification_with_indices.py)
- [CrossTextBinaryCorefTaskModule](src/pie_modules/taskmodules/cross_text_binary_coref.py)
- [LabeledSpanExtractionByTokenClassificationTaskModule](src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py)
- [ExtractiveQuestionAnsweringTaskModule](src/pie_modules/taskmodules/extractive_question_answering.py)
- [TextToTextTaskModule](src/pie_modules/taskmodules/text_to_text.py)
Expand Down
5 changes: 5 additions & 0 deletions src/pie_modules/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,8 @@ class GenerativeAnswer(AnnotationWithText):

score: Optional[float] = dataclasses.field(default=None, compare=False)
question: Optional[Question] = None


@dataclasses.dataclass(eq=True, frozen=True)
class BinaryCorefRelation(BinaryRelation):
label: str = "coref"
254 changes: 254 additions & 0 deletions src/pie_modules/document/processing/text_pair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
import copy
import logging
import random
from collections import defaultdict
from collections.abc import Iterator
from itertools import chain
from typing import Dict, Iterable, List, Optional, Tuple, TypeVar

from pytorch_ie.annotations import LabeledSpan, Span
from pytorch_ie.documents import (
TextDocumentWithLabeledSpansAndBinaryRelations,
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
from tqdm import tqdm

from pie_modules.documents import (
BinaryCorefRelation,
TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
)
from pie_modules.utils.span import are_nested

logger = logging.getLogger(__name__)

S = TypeVar("S", bound=Span)
S2 = TypeVar("S2", bound=Span)


def _span2partition_mapping(spans: Iterable[S], partitions: Iterable[S2]) -> Dict[S, S2]:
result = {}
for span in spans:
for partition in partitions:
if are_nested(
start_end=(span.start, span.end), other_start_end=(partition.start, partition.end)
):
result[span] = partition
break
return result


def _span_copy_shifted(span: S, offset: int) -> S:
return span.copy(start=span.start + offset, end=span.end + offset)


def _construct_text_pair_coref_documents_from_partitions_via_relations(
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, relation_label: str
) -> List[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]:
span2partition = _span2partition_mapping(
spans=document.labeled_spans, partitions=document.labeled_partitions
)
partition2spans = defaultdict(list)
for span, partition in span2partition.items():
partition2spans[partition].append(span)

texts2docs_and_span_mappings: Dict[
Tuple[str, str],
Tuple[
TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
Dict[LabeledSpan, LabeledSpan],
Dict[LabeledSpan, LabeledSpan],
],
] = dict()
result = []
for rel in document.binary_relations:
if rel.label != relation_label:
continue

if rel.head not in span2partition:
raise ValueError(f"head not in any partition: {rel.head}")
head_partition = span2partition[rel.head]
text = document.text[head_partition.start : head_partition.end]

if rel.tail not in span2partition:
raise ValueError(f"tail not in any partition: {rel.tail}")
tail_partition = span2partition[rel.tail]
text_pair = document.text[tail_partition.start : tail_partition.end]

if (text, text_pair) in texts2docs_and_span_mappings:
new_doc, head_spans_mapping, tail_spans_mapping = texts2docs_and_span_mappings[
(text, text_pair)
]
else:
if document.id is not None:
doc_id = (
f"{document.id}[{head_partition.start}:{head_partition.end}]"
f"+{document.id}[{tail_partition.start}:{tail_partition.end}]"
)
else:
doc_id = None
new_doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
id=doc_id, text=text, text_pair=text_pair
)

head_spans_mapping = {
span: _span_copy_shifted(span=span, offset=-head_partition.start)
for span in partition2spans[head_partition]
}
new_doc.labeled_spans.extend(head_spans_mapping.values())

tail_spans_mapping = {
span: _span_copy_shifted(span=span, offset=-tail_partition.start)
for span in partition2spans[tail_partition]
}
new_doc.labeled_spans_pair.extend(tail_spans_mapping.values())

texts2docs_and_span_mappings[(text, text_pair)] = (
new_doc,
head_spans_mapping,
tail_spans_mapping,
)
result.append(new_doc)

coref_rel = BinaryCorefRelation(
head=head_spans_mapping[rel.head], tail=tail_spans_mapping[rel.tail], score=1.0
)
new_doc.binary_coref_relations.append(coref_rel)

return result


def construct_text_pair_coref_documents_from_partitions_via_relations(
documents: Iterable[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions], **kwargs
) -> Iterator[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]:
for doc in documents:
yield from _construct_text_pair_coref_documents_from_partitions_via_relations(
document=doc, **kwargs
)


def shift_span(span: S, offset: int) -> S:
return span.copy(start=span.start + offset, end=span.end + offset)


def construct_text_document_from_text_pair_coref_document(
document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
glue_text: str,
no_relation_label: str,
relation_label_mapping: Optional[Dict[str, str]] = None,
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
if document.text == document.text_pair:
new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
id=document.id, metadata=copy.deepcopy(document.metadata), text=document.text
)
old2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
new2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
for old_span in chain(document.labeled_spans, document.labeled_spans_pair):
new_span = old_span.copy()
# when detaching / copying the span, it may be the same as a previous span from the other
new_span = new2new_spans.get(new_span, new_span)
new2new_spans[new_span] = new_span
old2new_spans[old_span] = new_span
else:
new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
text=document.text + glue_text + document.text_pair,
id=document.id,
metadata=copy.deepcopy(document.metadata),
)
old2new_spans = {}
old2new_spans.update({span: span.copy() for span in document.labeled_spans})
offset = len(document.text) + len(glue_text)
old2new_spans.update(
{span: shift_span(span.copy(), offset) for span in document.labeled_spans_pair}
)

# sort to make order deterministic
new_doc.labeled_spans.extend(
sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label))
)
for old_rel in document.binary_coref_relations:
label = old_rel.label if old_rel.score > 0.0 else no_relation_label
if relation_label_mapping is not None:
label = relation_label_mapping.get(label, label)
new_rel = old_rel.copy(
head=old2new_spans[old_rel.head],
tail=old2new_spans[old_rel.tail],
label=label,
score=1.0,
)
new_doc.binary_relations.append(new_rel)

return new_doc


def add_negative_coref_relations(
documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations],
downsampling_factor: Optional[float] = None,
) -> Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]:
positive_tuples = defaultdict(set)
text2spans = defaultdict(set)
for doc in documents:
for labeled_span in doc.labeled_spans:
text2spans[doc.text].add(labeled_span.copy())
for labeled_span in doc.labeled_spans_pair:
text2spans[doc.text_pair].add(labeled_span.copy())

for coref in doc.binary_coref_relations:
positive_tuples[(doc.text, doc.text_pair)].add((coref.head.copy(), coref.tail.copy()))
positive_tuples[(doc.text_pair, doc.text)].add((coref.tail.copy(), coref.head.copy()))

new_docs = []
new_rels2new_docs = {}
positive_rels = []
negative_rels = []
for text in tqdm(sorted(text2spans)):
for text_pair in sorted(text2spans):
current_positives = positive_tuples.get((text, text_pair), set())
new_doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
text=text, text_pair=text_pair
)
new_doc.labeled_spans.extend(labeled_span.copy() for labeled_span in text2spans[text])
new_doc.labeled_spans_pair.extend(
labeled_span.copy() for labeled_span in text2spans[text_pair]
)
for s in sorted(new_doc.labeled_spans):
for s_p in sorted(new_doc.labeled_spans_pair):
# exclude relations to itself
if text == text_pair and s.copy() == s_p.copy():
continue
if s.label != s_p.label:
continue
score = 1.0 if (s.copy(), s_p.copy()) in current_positives else 0.0
new_coref_rel = BinaryCorefRelation(head=s, tail=s_p, score=score)
# new_doc.binary_coref_relations.append(new_coref_rel)
new_rels2new_docs[new_coref_rel] = new_doc
if score > 0.0:
positive_rels.append(new_coref_rel)
else:
negative_rels.append(new_coref_rel)
new_docs.append(new_doc)

for rel in positive_rels:
new_rels2new_docs[rel].binary_coref_relations.append(rel)

# Downsampling of negatives. This requires positive instances!
if downsampling_factor is not None:
if len(positive_rels) == 0:
raise ValueError(
f"downsampling [factor={downsampling_factor}] is enabled, "
f"but no positive relations are available to calculate max_num_negative"
)

max_num_negative = int(len(positive_rels) * downsampling_factor)
if max_num_negative == 0:
logger.warning(
f"downsampling with factor={downsampling_factor} and number of "
f"positive relations={len(positive_rels)} does not produce any negatives"
)
else:
random.shuffle(negative_rels)
negative_rels = negative_rels[:max_num_negative]
for rel in negative_rels:
new_rels2new_docs[rel].binary_coref_relations.append(rel)

docs_with_rels = [doc for doc in new_docs if len(doc.binary_coref_relations) > 0]
return docs_with_rels
61 changes: 61 additions & 0 deletions src/pie_modules/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from pie_modules.annotations import (
AbstractiveSummary,
BinaryCorefRelation,
BinaryRelation,
ExtractiveAnswer,
GenerativeAnswer,
Expand Down Expand Up @@ -151,3 +152,63 @@ class TokenDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions(
TokenDocumentWithLabeledMultiSpansAndBinaryRelations,
):
pass


@dataclasses.dataclass
class WithTextPair:
text_pair: str


@dataclasses.dataclass
class WithLabeledSpansPair(WithTextPair):
labeled_spans_pair: AnnotationLayer[LabeledSpan] = annotation_field(target="text_pair")


@dataclasses.dataclass
class WithLabeledPartitionsPair(WithTextPair):
labeled_partitions_pair: AnnotationLayer[LabeledSpan] = annotation_field(target="text_pair")


@dataclasses.dataclass
class TextPairBasedDocument(TextBasedDocument, WithTextPair):
pass


@dataclasses.dataclass
class TextPairDocumentWithLabeledPartitions(
WithLabeledPartitionsPair, TextPairBasedDocument, TextDocumentWithLabeledPartitions
):
pass


@dataclasses.dataclass
class TextPairDocumentWithLabeledSpans(
WithLabeledSpansPair, TextPairBasedDocument, TextDocumentWithLabeledSpans
):
pass


@dataclasses.dataclass
class TextPairDocumentWithLabeledSpansAndLabeledPartitions(
TextPairDocumentWithLabeledPartitions,
TextPairDocumentWithLabeledSpans,
TextDocumentWithLabeledSpansAndLabeledPartitions,
):
pass


@dataclasses.dataclass
class TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
TextPairDocumentWithLabeledSpans, TextDocumentWithLabeledSpans
):
binary_coref_relations: AnnotationLayer[BinaryCorefRelation] = annotation_field(
targets=["labeled_spans", "labeled_spans_pair"]
)


@dataclasses.dataclass
class TextPairDocumentWithLabeledSpansSimilarityRelationsAndLabeledPartitions(
TextPairDocumentWithLabeledSpansAndLabeledPartitions,
TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
):
pass
5 changes: 4 additions & 1 deletion src/pie_modules/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .sequence_classification_with_pooler import SequenceClassificationModelWithPooler
from .sequence_classification_with_pooler import (
SequenceClassificationModelWithPooler,
SequencePairSimilarityModelWithPooler,
)
from .simple_extractive_question_answering import SimpleExtractiveQuestionAnsweringModel
from .simple_generative import SimpleGenerativeModel
from .simple_sequence_classification import SimpleSequenceClassificationModel
Expand Down
Loading

0 comments on commit cccdd33

Please sign in to comment.