task: cross-text-coref #110

Merged 49 commits on Sep 16, 2024
3923dbe
implement CrossTextBinaryCorefTaskModule
ArneBinder Sep 11, 2024
52a8b45
call save_hyperparameters()
ArneBinder Sep 11, 2024
71fb6e1
make taskmodule (future) model compliant
ArneBinder Sep 11, 2024
abec896
implement SimpleSimilarityModel
ArneBinder Sep 11, 2024
16caf81
use fixture data for documents_with_negatives
ArneBinder Sep 12, 2024
b2f739b
disentangle tests
ArneBinder Sep 12, 2024
119ef75
streamline test
ArneBinder Sep 12, 2024
6b90f99
improve test
ArneBinder Sep 12, 2024
5067e16
create negatives from text to itself (but different spans)
ArneBinder Sep 12, 2024
a93abde
restrict candidates by having same entity type
ArneBinder Sep 12, 2024
ddadf5e
make RelationStatisticsMixin ready for multi-label or binary
ArneBinder Sep 12, 2024
4768fe0
use RelationStatisticsMixin
ArneBinder Sep 12, 2024
5d0c7fa
rename model to SpanSimilarityModel; add similarity_threshold paramet…
ArneBinder Sep 12, 2024
cc2e828
remove SpanSimilarityModel in favor of new SequencePairSimilarityMode…
ArneBinder Sep 12, 2024
64fa61e
add tests for SequencePairSimilarityModelWithPooler
ArneBinder Sep 12, 2024
41af7ce
use mention pooling per default
ArneBinder Sep 12, 2024
ad89617
make pre-commit happy
ArneBinder Sep 12, 2024
2a974ea
implement unbatch_output() and create_annotations_from_output()
ArneBinder Sep 12, 2024
6a2e878
implement long text handling
ArneBinder Sep 12, 2024
617be19
set default label_threshold for SequencePairSimilarityModelWithPooler…
ArneBinder Sep 12, 2024
309346a
fix missed index shift because of added special tokens
ArneBinder Sep 12, 2024
8681f24
minor fixes
ArneBinder Sep 12, 2024
4376d16
add check for direction of coref relations (should point from text to…
ArneBinder Sep 12, 2024
ebd09b2
add documentation for SequencePairSimilarityModelWithPooler
ArneBinder Sep 12, 2024
98cc08f
add model and taskmodule to readme
ArneBinder Sep 12, 2024
9599ee2
add short documentation for CrossTextBinaryCorefTaskModule
ArneBinder Sep 12, 2024
0cb1708
outsource add_negative_relations() to document.processing.text_pair
ArneBinder Sep 13, 2024
6b0c032
update documents_with_negatives.json with current output
ArneBinder Sep 13, 2024
1901779
rename add_negative_relations() to add_negative_coref_relations() and…
ArneBinder Sep 13, 2024
8093a58
implement construct_text_pair_coref_documents_from_partitions_via_rel…
ArneBinder Sep 13, 2024
c0c2875
add tqdm to add_negative_coref_relations
ArneBinder Sep 13, 2024
81ea67c
move document and annotation types to documents and annotations modul…
ArneBinder Sep 15, 2024
b1ceac2
fix tokenization in encode_input
ArneBinder Sep 15, 2024
b5f0c8b
outsource get_aligned_token_span() and SpanNotAlignedWithTokenExcepti…
ArneBinder Sep 15, 2024
7487fd7
implement construct_text_document_from_text_pair_coref_document()
ArneBinder Sep 15, 2024
961a2a0
add parameter relation_label_mapping to construct_text_document_from_…
ArneBinder Sep 15, 2024
1e3b5d4
add parameter no_relation_label to construct_text_document_from_text_…
ArneBinder Sep 15, 2024
b8d9939
prepare downsampling of negatives
ArneBinder Sep 15, 2024
1d73f8a
add_negative_coref_relations does not return docs without relations
ArneBinder Sep 15, 2024
11de48b
implement downsampling for add_negative_coref_relations()
ArneBinder Sep 15, 2024
538e9ee
remove unused kwargs from add_negative_coref_relations()
ArneBinder Sep 15, 2024
833eb8c
allow that downsampling negatives does not produce negatives at all
ArneBinder Sep 16, 2024
e72fcc1
fix existing and add more metrics; rename "labels" / "probabilities" …
ArneBinder Sep 16, 2024
c294429
move label_threshold from model to taskmodule; rename "is_valid" to "…
ArneBinder Sep 16, 2024
7c68696
rename parameter "label_threshold" to "similarity_threshold"
ArneBinder Sep 16, 2024
049e69e
add metric: avg-P (BinaryAveragePrecision)
ArneBinder Sep 16, 2024
a30fc2f
use BinaryF1Score instead of MultiClassF1 (micro/macro/per-label)
ArneBinder Sep 16, 2024
904fac9
fix test
ArneBinder Sep 16, 2024
c76caca
cleanup
ArneBinder Sep 16, 2024
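
The last commits move to a `similarity_threshold` and binary metrics. As a rough illustration of what that combination means, here is a pure-Python sketch that thresholds similarity scores and computes F1 for the positive class. The function `binary_f1` and the example values are hypothetical stand-ins written for this note; the PR itself names `BinaryF1Score`, whose exact thresholding behavior is not shown on this page.

```python
from typing import Sequence


def binary_f1(
    similarities: Sequence[float],
    targets: Sequence[int],
    similarity_threshold: float,
) -> float:
    """F1 of the positive class after thresholding similarity scores.

    Hypothetical helper for illustration; not part of pie_modules.
    """
    # a pair counts as predicted coreferent if its similarity exceeds the threshold
    preds = [int(s > similarity_threshold) for s in similarities]
    tp = sum(1 for p, t in zip(preds, targets) if p == 1 and t == 1)
    fp = sum(1 for p, t in zip(preds, targets) if p == 1 and t == 0)
    fn = sum(1 for p, t in zip(preds, targets) if p == 0 and t == 1)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# made-up scores for four span pairs; 1 = coreferent, 0 = not coreferent
example_f1 = binary_f1([0.9, 0.2, 0.7, 0.4], [1, 0, 0, 1], similarity_threshold=0.5)
print(example_f1)
```

Raising or lowering the threshold trades precision against recall, which is presumably why a threshold-free metric such as average precision was added alongside F1.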
2 changes: 2 additions & 0 deletions README.md
@@ -16,6 +16,7 @@ Available models:

- [SimpleSequenceClassificationModel](src/pie_modules/models/simple_sequence_classification.py)
- [SequenceClassificationModelWithPooler](src/pie_modules/models/sequence_classification_with_pooler.py)
- [SequencePairSimilarityModelWithPooler](src/pie_modules/models/sequence_classification_with_pooler.py)
- [SimpleTokenClassificationModel](src/pie_modules/models/simple_token_classification.py)
- [TokenClassificationModelWithSeq2SeqEncoderAndCrf](src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py)
- [SimpleExtractiveQuestionAnsweringModel](src/pie_modules/models/simple_extractive_question_answering.py)
@@ -25,6 +26,7 @@ Available models:
Available taskmodules:

- [RETextClassificationWithIndicesTaskModule](src/pie_modules/taskmodules/re_text_classification_with_indices.py)
- [CrossTextBinaryCorefTaskModule](src/pie_modules/taskmodules/cross_text_binary_coref.py)
- [LabeledSpanExtractionByTokenClassificationTaskModule](src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py)
- [ExtractiveQuestionAnsweringTaskModule](src/pie_modules/taskmodules/extractive_question_answering.py)
- [TextToTextTaskModule](src/pie_modules/taskmodules/text_to_text.py)
5 changes: 5 additions & 0 deletions src/pie_modules/annotations.py
@@ -63,3 +63,8 @@ class GenerativeAnswer(AnnotationWithText):

    score: Optional[float] = dataclasses.field(default=None, compare=False)
    question: Optional[Question] = None


@dataclasses.dataclass(eq=True, frozen=True)
class BinaryCorefRelation(BinaryRelation):
    label: str = "coref"
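
To see what this small class changes, here is a sketch with plain dataclass stand-ins. `Span` and `BinaryRelation` below are simplified assumptions made for illustration, not the pytorch_ie definitions; only the `label: str = "coref"` default is taken from the hunk above, and the `score` field is assumed because the code later in this PR passes `score=1.0` and `score=0.0`.

```python
import dataclasses


# Simplified stand-ins (assumptions for illustration only); the real base
# classes live in pytorch_ie / pie_modules and carry more behavior.
@dataclasses.dataclass(eq=True, frozen=True)
class Span:
    start: int
    end: int


@dataclasses.dataclass(eq=True, frozen=True)
class BinaryRelation:
    head: Span
    tail: Span
    label: str
    score: float = 1.0


@dataclasses.dataclass(eq=True, frozen=True)
class BinaryCorefRelation(BinaryRelation):
    # the subclass only fixes a default label, so callers can omit it
    label: str = "coref"


positive = BinaryCorefRelation(head=Span(0, 3), tail=Span(0, 2))
negative = BinaryCorefRelation(head=Span(0, 3), tail=Span(4, 7), score=0.0)
print(positive.label, positive.score, negative.score)
```

In this sketch a negative pair is not a different class or label; it is the same relation type with a score of 0.0, which matches how `add_negative_coref_relations()` builds negatives in this PR.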
254 changes: 254 additions & 0 deletions src/pie_modules/document/processing/text_pair.py
@@ -0,0 +1,254 @@
import copy
import logging
import random
from collections import defaultdict
from collections.abc import Iterator
from itertools import chain
from typing import Dict, Iterable, List, Optional, Tuple, TypeVar

from pytorch_ie.annotations import LabeledSpan, Span
from pytorch_ie.documents import (
    TextDocumentWithLabeledSpansAndBinaryRelations,
    TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
from tqdm import tqdm

from pie_modules.documents import (
    BinaryCorefRelation,
    TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
)
from pie_modules.utils.span import are_nested

logger = logging.getLogger(__name__)

S = TypeVar("S", bound=Span)
S2 = TypeVar("S2", bound=Span)


def _span2partition_mapping(spans: Iterable[S], partitions: Iterable[S2]) -> Dict[S, S2]:
    result = {}
    for span in spans:
        for partition in partitions:
            if are_nested(
                start_end=(span.start, span.end), other_start_end=(partition.start, partition.end)
            ):
                result[span] = partition
                break
    return result
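
For readers following the offset arithmetic, here is a small stand-alone sketch of the mapping step using plain `(start, end)` tuples. The containment check `is_contained` is an assumption standing in for `are_nested`, whose exact semantics are not shown in this diff, and the example text is made up.

```python
from typing import Dict, Iterable, List, Tuple

Offsets = Tuple[int, int]  # (start, end) character offsets, end exclusive


def is_contained(span: Offsets, partition: Offsets) -> bool:
    # assumption: a span belongs to a partition if it lies fully inside it
    return partition[0] <= span[0] and span[1] <= partition[1]


def span_to_partition(
    spans: Iterable[Offsets], partitions: List[Offsets]
) -> Dict[Offsets, Offsets]:
    result: Dict[Offsets, Offsets] = {}
    for span in spans:
        for partition in partitions:
            if is_contained(span, partition):
                # first matching partition wins, as in the loop above
                result[span] = partition
                break
    return result


def shift(span: Offsets, offset: int) -> Offsets:
    return (span[0] + offset, span[1] + offset)


text = "Bob left. He was tired."
partitions = [(0, 9), (10, 23)]  # the two sentences
spans = [(0, 3), (10, 12)]  # "Bob" and "He"

mapping = span_to_partition(spans, partitions)
he_partition = mapping[(10, 12)]
# make the span relative to the start of its partition
he_local = shift((10, 12), -he_partition[0])
partition_text = text[he_partition[0] : he_partition[1]]
print(partition_text[he_local[0] : he_local[1]])
```

Spans that fall into no partition are simply absent from the mapping, which is why the caller further down raises a `ValueError` for a relation head or tail it cannot look up.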


def _span_copy_shifted(span: S, offset: int) -> S:
    return span.copy(start=span.start + offset, end=span.end + offset)


def _construct_text_pair_coref_documents_from_partitions_via_relations(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, relation_label: str
) -> List[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]:
    span2partition = _span2partition_mapping(
        spans=document.labeled_spans, partitions=document.labeled_partitions
    )
    partition2spans = defaultdict(list)
    for span, partition in span2partition.items():
        partition2spans[partition].append(span)

    texts2docs_and_span_mappings: Dict[
        Tuple[str, str],
        Tuple[
            TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
            Dict[LabeledSpan, LabeledSpan],
            Dict[LabeledSpan, LabeledSpan],
        ],
    ] = dict()
    result = []
    for rel in document.binary_relations:
        if rel.label != relation_label:
            continue

        if rel.head not in span2partition:
            raise ValueError(f"head not in any partition: {rel.head}")
        head_partition = span2partition[rel.head]
        text = document.text[head_partition.start : head_partition.end]

        if rel.tail not in span2partition:
            raise ValueError(f"tail not in any partition: {rel.tail}")
        tail_partition = span2partition[rel.tail]
        text_pair = document.text[tail_partition.start : tail_partition.end]

        if (text, text_pair) in texts2docs_and_span_mappings:
            new_doc, head_spans_mapping, tail_spans_mapping = texts2docs_and_span_mappings[
                (text, text_pair)
            ]
        else:
            if document.id is not None:
                doc_id = (
                    f"{document.id}[{head_partition.start}:{head_partition.end}]"
                    f"+{document.id}[{tail_partition.start}:{tail_partition.end}]"
                )
            else:
                doc_id = None
            new_doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
                id=doc_id, text=text, text_pair=text_pair
            )

            head_spans_mapping = {
                span: _span_copy_shifted(span=span, offset=-head_partition.start)
                for span in partition2spans[head_partition]
            }
            new_doc.labeled_spans.extend(head_spans_mapping.values())

            tail_spans_mapping = {
                span: _span_copy_shifted(span=span, offset=-tail_partition.start)
                for span in partition2spans[tail_partition]
            }
            new_doc.labeled_spans_pair.extend(tail_spans_mapping.values())

            texts2docs_and_span_mappings[(text, text_pair)] = (
                new_doc,
                head_spans_mapping,
                tail_spans_mapping,
            )
            result.append(new_doc)

        coref_rel = BinaryCorefRelation(
            head=head_spans_mapping[rel.head], tail=tail_spans_mapping[rel.tail], score=1.0
        )
        new_doc.binary_coref_relations.append(coref_rel)

    return result


def construct_text_pair_coref_documents_from_partitions_via_relations(
    documents: Iterable[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions], **kwargs
) -> Iterator[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]:
    for doc in documents:
        yield from _construct_text_pair_coref_documents_from_partitions_via_relations(
            document=doc, **kwargs
        )
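
The grouping logic above (one pair document per distinct `(text, text_pair)`, with an id derived from the partition offsets) can be sketched without the document classes. Everything below uses plain dicts and tuples as stand-ins; `build_pairs` and `pair_doc_id` are hypothetical names, and only the id format and the grouping key follow the code in this diff.

```python
from typing import Dict, List, Optional, Tuple

Offsets = Tuple[int, int]


def pair_doc_id(doc_id: Optional[str], head_part: Offsets, tail_part: Offsets) -> Optional[str]:
    # same id scheme as in the function above: "<id>[s:e]+<id>[s:e]"
    if doc_id is None:
        return None
    return (
        f"{doc_id}[{head_part[0]}:{head_part[1]}]"
        f"+{doc_id}[{tail_part[0]}:{tail_part[1]}]"
    )


def build_pairs(
    doc_id: Optional[str],
    text: str,
    relations: List[Tuple[Offsets, Offsets]],
    span2partition: Dict[Offsets, Offsets],
) -> List[dict]:
    pairs: Dict[Tuple[str, str], dict] = {}
    for head, tail in relations:
        head_part, tail_part = span2partition[head], span2partition[tail]
        key = (text[head_part[0] : head_part[1]], text[tail_part[0] : tail_part[1]])
        if key not in pairs:
            # first relation between these two texts creates the pair document
            pairs[key] = {
                "id": pair_doc_id(doc_id, head_part, tail_part),
                "text": key[0],
                "text_pair": key[1],
                "coref": [],
            }
        # store head/tail with offsets relative to their own text
        pairs[key]["coref"].append(
            (
                (head[0] - head_part[0], head[1] - head_part[0]),
                (tail[0] - tail_part[0], tail[1] - tail_part[0]),
            )
        )
    return list(pairs.values())


text = "Bob left. He was tired."
span2partition = {(0, 3): (0, 9), (10, 12): (10, 23)}
pair_docs = build_pairs("doc1", text, [((0, 3), (10, 12))], span2partition)
print(pair_docs[0]["id"])
```

Note that the grouping key is the text content, not the partition offsets, so two partitions with identical text would share one pair document.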


def shift_span(span: S, offset: int) -> S:
    return span.copy(start=span.start + offset, end=span.end + offset)


def construct_text_document_from_text_pair_coref_document(
    document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
    glue_text: str,
    no_relation_label: str,
    relation_label_mapping: Optional[Dict[str, str]] = None,
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
    if document.text == document.text_pair:
        new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
            id=document.id, metadata=copy.deepcopy(document.metadata), text=document.text
        )
        old2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
        new2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
        for old_span in chain(document.labeled_spans, document.labeled_spans_pair):
            new_span = old_span.copy()
            # when detaching / copying the span, it may be the same as a previous span from the other
            new_span = new2new_spans.get(new_span, new_span)
            new2new_spans[new_span] = new_span
            old2new_spans[old_span] = new_span
    else:
        new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
            text=document.text + glue_text + document.text_pair,
            id=document.id,
            metadata=copy.deepcopy(document.metadata),
        )
        old2new_spans = {}
        old2new_spans.update({span: span.copy() for span in document.labeled_spans})
        offset = len(document.text) + len(glue_text)
        old2new_spans.update(
            {span: shift_span(span.copy(), offset) for span in document.labeled_spans_pair}
        )

    # sort to make order deterministic
    new_doc.labeled_spans.extend(
        sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label))
    )
    for old_rel in document.binary_coref_relations:
        label = old_rel.label if old_rel.score > 0.0 else no_relation_label
        if relation_label_mapping is not None:
            label = relation_label_mapping.get(label, label)
        new_rel = old_rel.copy(
            head=old2new_spans[old_rel.head],
            tail=old2new_spans[old_rel.tail],
            label=label,
            score=1.0,
        )
        new_doc.binary_relations.append(new_rel)

    return new_doc
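
A compact way to check the offset and label handling of this conversion is the sketch below. It works on plain strings and `(start, end)` tuples; `merge_text_pair` and `resolve_label` are hypothetical helper names, and the label strings in the usage lines are made up for the example.

```python
from typing import Dict, List, Optional, Tuple

Offsets = Tuple[int, int]


def merge_text_pair(
    text: str,
    text_pair: str,
    spans: List[Offsets],
    spans_pair: List[Offsets],
    glue_text: str,
) -> Tuple[str, List[Offsets]]:
    if text == text_pair:
        # identical texts are not concatenated; equal spans collapse into one
        return text, sorted(set(spans) | set(spans_pair))
    # spans of the second text move behind the first text plus the glue
    offset = len(text) + len(glue_text)
    shifted = [(start + offset, end + offset) for start, end in spans_pair]
    return text + glue_text + text_pair, sorted(list(spans) + shifted)


def resolve_label(
    label: str,
    score: float,
    no_relation_label: str,
    relation_label_mapping: Optional[Dict[str, str]] = None,
) -> str:
    # a score of 0.0 marks a negative pair, which gets the no-relation label
    label = label if score > 0.0 else no_relation_label
    if relation_label_mapping is not None:
        label = relation_label_mapping.get(label, label)
    return label


merged_text, merged_spans = merge_text_pair(
    "Bob left.", "He was tired.", [(0, 3)], [(0, 2)], glue_text="\n\n"
)
print(repr(merged_text), merged_spans)
print(resolve_label("coref", 0.0, no_relation_label="no_relation"))
```

The resulting relations all carry `score=1.0` in the code above, so after conversion the positive/negative distinction lives only in the label.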


def add_negative_coref_relations(
    documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations],
    downsampling_factor: Optional[float] = None,
) -> Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]:
    positive_tuples = defaultdict(set)
    text2spans = defaultdict(set)
    for doc in documents:
        for labeled_span in doc.labeled_spans:
            text2spans[doc.text].add(labeled_span.copy())
        for labeled_span in doc.labeled_spans_pair:
            text2spans[doc.text_pair].add(labeled_span.copy())

        for coref in doc.binary_coref_relations:
            positive_tuples[(doc.text, doc.text_pair)].add((coref.head.copy(), coref.tail.copy()))
            positive_tuples[(doc.text_pair, doc.text)].add((coref.tail.copy(), coref.head.copy()))

    new_docs = []
    new_rels2new_docs = {}
    positive_rels = []
    negative_rels = []
    for text in tqdm(sorted(text2spans)):
        for text_pair in sorted(text2spans):
            current_positives = positive_tuples.get((text, text_pair), set())
            new_doc = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
                text=text, text_pair=text_pair
            )
            new_doc.labeled_spans.extend(labeled_span.copy() for labeled_span in text2spans[text])
            new_doc.labeled_spans_pair.extend(
                labeled_span.copy() for labeled_span in text2spans[text_pair]
            )
            for s in sorted(new_doc.labeled_spans):
                for s_p in sorted(new_doc.labeled_spans_pair):
                    # exclude relations to itself
                    if text == text_pair and s.copy() == s_p.copy():
                        continue
                    if s.label != s_p.label:
                        continue
                    score = 1.0 if (s.copy(), s_p.copy()) in current_positives else 0.0
                    new_coref_rel = BinaryCorefRelation(head=s, tail=s_p, score=score)
                    # new_doc.binary_coref_relations.append(new_coref_rel)
                    new_rels2new_docs[new_coref_rel] = new_doc
                    if score > 0.0:
                        positive_rels.append(new_coref_rel)
                    else:
                        negative_rels.append(new_coref_rel)
            new_docs.append(new_doc)

    for rel in positive_rels:
        new_rels2new_docs[rel].binary_coref_relations.append(rel)

    # Downsampling of negatives. This requires positive instances!
    if downsampling_factor is not None:
        if len(positive_rels) == 0:
            raise ValueError(
                f"downsampling [factor={downsampling_factor}] is enabled, "
                f"but no positive relations are available to calculate max_num_negative"
            )

        max_num_negative = int(len(positive_rels) * downsampling_factor)
        if max_num_negative == 0:
            logger.warning(
                f"downsampling with factor={downsampling_factor} and number of "
                f"positive relations={len(positive_rels)} does not produce any negatives"
            )
        else:
            random.shuffle(negative_rels)
        negative_rels = negative_rels[:max_num_negative]
    for rel in negative_rels:
        new_rels2new_docs[rel].binary_coref_relations.append(rel)

    docs_with_rels = [doc for doc in new_docs if len(doc.binary_coref_relations) > 0]
    return docs_with_rels
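
The downsampling rule at the end of this function (at most `int(num_positives * downsampling_factor)` negatives) can be checked in isolation. The sketch below is a hypothetical helper, not the function from this PR: it assumes that all negatives are kept when no factor is given, and it adds a `seed` parameter for reproducibility, whereas the code above shuffles with the module-level `random.shuffle`.

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def downsample_negatives(
    positives: Sequence[T],
    negatives: Sequence[T],
    downsampling_factor: Optional[float],
    seed: Optional[int] = None,  # assumption: not a parameter of the original function
) -> List[T]:
    if downsampling_factor is None:
        # assumption: without a factor, every negative is kept
        return list(negatives)
    if len(positives) == 0:
        raise ValueError("downsampling requires at least one positive relation")
    # the budget scales with the number of positives and is truncated to an integer
    max_num_negative = int(len(positives) * downsampling_factor)
    shuffled = list(negatives)
    random.Random(seed).shuffle(shuffled)
    return shuffled[:max_num_negative]


positives = ["p1", "p2"]
negatives = ["n1", "n2", "n3", "n4", "n5"]

kept = downsample_negatives(positives, negatives, downsampling_factor=1.5, seed=0)
none_kept = downsample_negatives(positives, negatives, downsampling_factor=0.4, seed=0)
print(len(kept), len(none_kept))
```

With two positives, a factor of 0.4 truncates to a budget of zero, which is the case the warning in the function above reports.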
61 changes: 61 additions & 0 deletions src/pie_modules/documents.py
@@ -27,6 +27,7 @@

from pie_modules.annotations import (
    AbstractiveSummary,
    BinaryCorefRelation,
    BinaryRelation,
    ExtractiveAnswer,
    GenerativeAnswer,
@@ -151,3 +152,63 @@ class TokenDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions(
    TokenDocumentWithLabeledMultiSpansAndBinaryRelations,
):
    pass


@dataclasses.dataclass
class WithTextPair:
    text_pair: str


@dataclasses.dataclass
class WithLabeledSpansPair(WithTextPair):
    labeled_spans_pair: AnnotationLayer[LabeledSpan] = annotation_field(target="text_pair")


@dataclasses.dataclass
class WithLabeledPartitionsPair(WithTextPair):
    labeled_partitions_pair: AnnotationLayer[LabeledSpan] = annotation_field(target="text_pair")


@dataclasses.dataclass
class TextPairBasedDocument(TextBasedDocument, WithTextPair):
    pass


@dataclasses.dataclass
class TextPairDocumentWithLabeledPartitions(
    WithLabeledPartitionsPair, TextPairBasedDocument, TextDocumentWithLabeledPartitions
):
    pass


@dataclasses.dataclass
class TextPairDocumentWithLabeledSpans(
    WithLabeledSpansPair, TextPairBasedDocument, TextDocumentWithLabeledSpans
):
    pass


@dataclasses.dataclass
class TextPairDocumentWithLabeledSpansAndLabeledPartitions(
    TextPairDocumentWithLabeledPartitions,
    TextPairDocumentWithLabeledSpans,
    TextDocumentWithLabeledSpansAndLabeledPartitions,
):
    pass


@dataclasses.dataclass
class TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
    TextPairDocumentWithLabeledSpans, TextDocumentWithLabeledSpans
):
    binary_coref_relations: AnnotationLayer[BinaryCorefRelation] = annotation_field(
        targets=["labeled_spans", "labeled_spans_pair"]
    )


@dataclasses.dataclass
class TextPairDocumentWithLabeledSpansSimilarityRelationsAndLabeledPartitions(
    TextPairDocumentWithLabeledSpansAndLabeledPartitions,
    TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
):
    pass
5 changes: 4 additions & 1 deletion src/pie_modules/models/__init__.py
@@ -1,4 +1,7 @@
-from .sequence_classification_with_pooler import SequenceClassificationModelWithPooler
+from .sequence_classification_with_pooler import (
+    SequenceClassificationModelWithPooler,
+    SequencePairSimilarityModelWithPooler,
+)
from .simple_extractive_question_answering import SimpleExtractiveQuestionAnsweringModel
from .simple_generative import SimpleGenerativeModel
from .simple_sequence_classification import SimpleSequenceClassificationModel