diff --git a/src/pie_modules/taskmodules/common/interfaces.py b/src/pie_modules/taskmodules/common/interfaces.py index 7a994d20c..0d5915a2d 100644 --- a/src/pie_modules/taskmodules/common/interfaces.py +++ b/src/pie_modules/taskmodules/common/interfaces.py @@ -1,8 +1,12 @@ import abc +import logging +from collections import defaultdict from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar from pytorch_ie import Annotation +logger = logging.getLogger(__name__) + # Annotation Encoding type: encoding for a single annotation AE = TypeVar("AE") # Annotation type @@ -17,9 +21,20 @@ class DecodingException(Exception, Generic[AE], abc.ABC): identifier: str - def __init__(self, message: str, encoding: AE): + def __init__(self, message: str, encoding: AE, remaining: Optional[AE] = None): self.message = message self.encoding = encoding + self.remaining = remaining + + +class EncodingException(Exception, Generic[A], abc.ABC): + """Exception raised when encoding fails.""" + + identifier: str + + def __init__(self, message: str, annotation: A): + self.message = message + self.annotation = annotation class AnnotationEncoderDecoder(abc.ABC, Generic[A, AE]): @@ -32,3 +47,64 @@ def encode(self, annotation: A, metadata: Optional[Dict[str, Any]] = None) -> AE @abc.abstractmethod def decode(self, encoding: AE, metadata: Optional[Dict[str, Any]] = None) -> A: pass + + +class GenerativeAnnotationEncoderDecoder(AnnotationEncoderDecoder[A, AE], abc.ABC): + """Base class for generative annotation encoders and decoders.""" + + @abc.abstractmethod + def parse(self, encoding: AE, decoded_annotations: List[A], text_length: int) -> Tuple[A, AE]: + """Parse the encoding and return the decoded annotation and the remaining encoding.""" + pass + + +class GenerativeAnnotationEncoderDecoderWithParseWithErrors( + Generic[A], GenerativeAnnotationEncoderDecoder[A, List[int]], abc.ABC +): + KEY_INVALID_CORRECT = "correct" + + def parse_with_error_handling( + self, + encoding: List[int], + input_length: int, + stop_ids: List[int], + errors: Optional[Dict[str, int]] = None, + decoded_annotations: Optional[List[A]] = None, + disrespect_decoded_annotations: bool = False, + ) -> Tuple[List[A], Dict[str, int], List[int]]: + errors = errors or defaultdict(int) + decoded_annotations = decoded_annotations or [] + valid_encoding: A + successfully_decoded: List[int] = [] + remaining = encoding + prev_len = len(remaining) + while len(remaining) > 0: + if remaining[0] in stop_ids: + # we discard everything after any stop id + break + try: + valid_encoding, remaining = self.parse( + encoding=remaining, + decoded_annotations=decoded_annotations + if not disrespect_decoded_annotations + else [], + text_length=input_length, + ) + decoded_annotations.append(valid_encoding) + errors[self.KEY_INVALID_CORRECT] += 1 + successfully_decoded = encoding[: len(encoding) - len(remaining)] + except DecodingException as e: + if e.remaining is None: + raise ValueError(f"decoding exception did not return remaining encoding: {e}") + errors[e.identifier] += 1 + remaining = e.remaining + + # if we did not consume any ids, we discard the first remaining one + if len(remaining) == prev_len: + logger.warning( + f"parse did not consume any ids, discarding first id from {remaining}" + ) + remaining = remaining[1:] + prev_len = len(remaining) + + return decoded_annotations, dict(errors), encoding[len(successfully_decoded) :] diff --git a/src/pie_modules/taskmodules/pointer_network/annotation_encoder_decoder.py b/src/pie_modules/taskmodules/pointer_network/annotation_encoder_decoder.py index 66304a646..cfbc13f0f 100644 --- a/src/pie_modules/taskmodules/pointer_network/annotation_encoder_decoder.py +++ b/src/pie_modules/taskmodules/pointer_network/annotation_encoder_decoder.py @@ -1,10 +1,18 @@ import logging -from typing import Any, Dict, List, Optional, Set +from collections import defaultdict +from typing import Any, Dict, List, Optional, Set, Tuple, Type +from pytorch_ie import Annotation from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span -from pie_modules.taskmodules.common import AnnotationEncoderDecoder -from pie_modules.taskmodules.common.interfaces import DecodingException +from pie_modules.annotations import LabeledMultiSpan +from pie_modules.taskmodules.common.interfaces import ( + DecodingException, + EncodingException, + GenerativeAnnotationEncoderDecoder, + GenerativeAnnotationEncoderDecoderWithParseWithErrors, +) +from pie_modules.utils.span import are_nested, have_overlap logger = logging.getLogger(__name__) @@ -21,19 +29,100 @@ class DecodingSpanOverlapException(DecodingException[List[int]]): identifier = "overlap" +class DecodingSpanNestedException(DecodingException[List[int]]): + identifier = "nested" + + class DecodingLabelException(DecodingException[List[int]]): identifier = "label" class DecodingNegativeIndexException(DecodingException[List[int]]): - identifier = "index" + identifier = "negative_index" + + +class DecodingEmptySpanException(DecodingException[List[int]]): + identifier = "empty_span" + + +class IncompleteEncodingException(DecodingException[List[int]]): + identifier = "incomplete" + + def __init__(self, message: str, encoding: List[int], follow_up_candidates: List[int]): + super().__init__(message, encoding, remaining=[]) + self.follow_up_candidates = follow_up_candidates + + +class EncodingEmptySpanException(EncodingException[Span]): + identifier = "empty_span" + + +class EncodingEmptySlicesException(EncodingException[LabeledMultiSpan]): + identifier = "empty_slices" + + +def spans_have_overlap(span: Span, other_span: Span) -> bool: + start_end = (span.start, span.end) + other_start_end = (other_span.start, other_span.end) + return have_overlap(start_end=start_end, other_start_end=other_start_end) and not are_nested( + start_end=start_end, other_start_end=other_start_end + ) -class SpanEncoderDecoder(AnnotationEncoderDecoder[Span, List[int]]): - def __init__(self, exclusive_end: bool = True): +def spans_are_nested(span: Span, other_span: Span) -> bool: + return are_nested( + start_end=(span.start, span.end), other_start_end=(other_span.start, other_span.end) + ) + + +def _parse_label( + encoding: List[int], id2label: Dict[int, str], annotation_type: Type[Annotation] +) -> Tuple[str, List[int]]: + if len(encoding) == 0: + raise IncompleteEncodingException( + f"the encoding has not enough values to decode as {annotation_type.__name__}", + encoding=encoding, + follow_up_candidates=sorted(id2label.keys()), + ) + label_encoding = encoding[0] + remaining = encoding[1:] + if label_encoding not in id2label: + raise DecodingLabelException( + f"unknown label id: {label_encoding} (id2label: {id2label})", + encoding=encoding, + remaining=remaining, + ) + label = id2label[label_encoding] + return label, remaining + + +class SpanEncoderDecoder(GenerativeAnnotationEncoderDecoder[Span, List[int]]): + """An encoder-decoder for Spans. + + This encoder-decoder encodes a Span annotation as a list of two integers, the start and end index + of the span. Note that the end index of the Span annotation is exclusive, i.e. the span covers the + indices [start, end). However, the end index can be encoded as exclusive or inclusive, depending on + the `exclusive_end` parameter. Note that empty spans are not allowed, i.e. the start index must be + smaller than the end index. + + Args: + exclusive_end (bool, optional): Whether the end index will be encoded as exclusive or inclusive, i.e. the + encoded span covers the indices [encoded_start, encoded_end) or [encoded_start, encoded_end]. + Defaults to True. + allow_nested (bool, optional): Whether nested spans are allowed during parsing. If set to False, parsing + will raise an exception if a span is completely within another span. Defaults to False. + """ + + def __init__(self, exclusive_end: bool = True, allow_nested: bool = False): self.exclusive_end = exclusive_end + self.allow_nested = allow_nested def encode(self, annotation: Span, metadata: Optional[Dict[str, Any]] = None) -> List[int]: + if annotation.start == annotation.end: + raise EncodingEmptySpanException( + "can not encode empty Span annotations, i.e. where the start index equals the end index", + annotation=annotation, + ) end_idx = annotation.end if not self.exclusive_end: end_idx -= 1 @@ -60,6 +149,148 @@ def decode(self, encoding: List[int], metadata: Optional[Dict[str, Any]] = None) ) return Span(start=encoding[0], end=end_idx) + def parse( + self, + encoding: List[int], + decoded_annotations: List[Span], + text_length: int, + ) -> Tuple[Span, List[int]]: + exclusive_end_offset = 0 if self.exclusive_end else 1 + # the encoding is incomplete if it is empty, collect follow-up candidate indices + if len(encoding) == 0: + if self.allow_nested: + # everything is allowed + follow_up_candidates = list(range(text_length)) + else: + # exclude indices that are already covered by other annotations + nested_indices: Set[int] = set() + for previous_span in decoded_annotations: + # +1 because we allow to generate the exact same spans again + nested_indices.update(range(previous_span.start + 1, previous_span.end)) + follow_up_candidates = [ + idx for idx in range(text_length) if idx not in nested_indices + ] + raise IncompleteEncodingException( + "the encoding has not enough values to decode as Span", + encoding=encoding, + follow_up_candidates=follow_up_candidates, + ) + # the encoding is incomplete if it has only one value, collect follow-up candidate indices + elif len(encoding) == 1: + covering_spans = { + ann for ann in decoded_annotations if ann.start <= encoding[0] < ann.end + } + if self.allow_nested: + # exclude spans that overlap other spans, i.e. if encoding[0] is in another span, the next + # candidate should be also within this span + if len(covering_spans) == 0: + # allow all indices outside spans, after the start index + nested_indices = set() + for span in decoded_annotations: + # -1 because the end is outside the span + nested_indices = nested_indices.union(set(range(span.start, span.end - 1))) + + follow_up_candidates = [ + idx + 1 - exclusive_end_offset + for idx in range(encoding[0], text_length) + if idx not in nested_indices + ] + else: + # allow all indices that are within *all* covering spans, i.e. the smallest covering span, + # and after the start index + nested_indices = set(range(0, text_length)) + for span in covering_spans: + # + 1 because we want to include the (exclusive) end index + nested_indices = nested_indices.intersection( + set(range(span.start, span.end + 1)) + ) + follow_up_candidates = [ + idx - exclusive_end_offset for idx in nested_indices if idx > encoding[0] + ] + elif len(covering_spans) > 0: + if len(covering_spans) > 1: + raise ValueError( + f"more than one covering span found, but allow_nested=False. This should not happen. " + f"covering spans: {covering_spans}" + ) + covering_span = list(covering_spans)[0] + # if we generated the start of an existing span, we need to generate the exact end next + follow_up_candidates = [covering_span.end - exclusive_end_offset] + else: + # allow all indices after the start index and before the next span. we add a dummy span to + # correctly handle the case where no other spans are present + dummy_span = Span(start=text_length, end=text_length + 1) + next_span_start = min( + ann.start + for ann in decoded_annotations + [dummy_span] + if encoding[0] <= ann.start + ) + # +1 because we disallow empty spans + min_index = encoding[0] + 1 + # +1 because the end index is exclusive + max_index_exclusive = next_span_start + 1 + follow_up_candidates = list( + range( + min_index - exclusive_end_offset, + max_index_exclusive - exclusive_end_offset, + ) + ) + raise IncompleteEncodingException( + "the encoding has not enough values to decode as Span", + encoding=encoding, + follow_up_candidates=follow_up_candidates, + ) + # the encoding is complete, try to decode the span + else: + start_idx = encoding[0] + end_idx = encoding[1] + remaining = encoding[2:] + # the end index for Span annotations is exclusive, so we need to add 1 to the end index + if not self.exclusive_end: + end_idx += 1 + if end_idx == start_idx: + raise DecodingEmptySpanException( + "end index can not be equal to start index to decode as Span, but got: " + f"start={start_idx}, end={end_idx}", + encoding=encoding, + remaining=remaining, + ) + if end_idx < start_idx: + raise DecodingOrderException( + f"end index can not be smaller than start index, " + f"but got: start={start_idx}, end={end_idx}", + encoding=encoding, + remaining=remaining, + ) + if any(idx < 0 for idx in [start_idx, end_idx]): + raise DecodingNegativeIndexException( + f"indices must be positive, but got: start={start_idx}, end={end_idx}", + encoding=encoding, + remaining=remaining, + ) + # check overlap and nesting with previously decoded spans + span = Span(start=start_idx, end=end_idx) + for previous_span in decoded_annotations: + simple_previous_span = Span(start=previous_span.start, end=previous_span.end) + if span != simple_previous_span: + if spans_have_overlap(span=span, other_span=simple_previous_span): + raise DecodingSpanOverlapException( + f"the encoded span overlaps with another span: {previous_span}", + encoding=encoding, + remaining=remaining, + ) + if not self.allow_nested and spans_are_nested( + span=span, other_span=simple_previous_span + ): + raise DecodingSpanNestedException( + f"the encoded span is nested in another span: {previous_span}. " + "You can set allow_nested=True to allow nested spans.", + encoding=encoding, + remaining=remaining, + ) + + return span, remaining + class SpanEncoderDecoderWithOffset(SpanEncoderDecoder): def __init__(self, offset: int, **kwargs): @@ -74,11 +305,37 @@ def decode(self, encoding: List[int], metadata: Optional[Dict[str, Any]] = None) encoding = [x - self.offset for x in encoding] return super().decode(encoding=encoding, metadata=metadata) - -class LabeledSpanEncoderDecoder(AnnotationEncoderDecoder[LabeledSpan, List[int]]): + def parse( + self, + encoding: List[int], + decoded_annotations: List[Span], + text_length: int, + ) -> Tuple[Span, List[int]]: + encoding_without_offset = [x - self.offset for x in encoding] + try: + span, remaining = super().parse( + encoding=encoding_without_offset, + decoded_annotations=decoded_annotations, + text_length=text_length, + ) + # and also to the follow-up candidates if present + except DecodingException as e: + # we need to add the offset to the remaining encoding + # and also to the follow-up candidates if any of them is present + kwargs = {} + if e.remaining is not None and not isinstance(e, IncompleteEncodingException): + kwargs["remaining"] = [x + self.offset for x in e.remaining] + if isinstance(e, IncompleteEncodingException): + kwargs["follow_up_candidates"] = [x + self.offset for x in e.follow_up_candidates] + raise type(e)(e.message, encoding=encoding, **kwargs) + # use the original encoding, i.e. with any potential offset, to get the remaining encoding + return span, encoding[len(encoding) - len(remaining) :] + + +class LabeledSpanEncoderDecoder(GenerativeAnnotationEncoderDecoder[LabeledSpan, List[int]]): def __init__( self, - span_encoder_decoder: AnnotationEncoderDecoder[Span, List[int]], + span_encoder_decoder: SpanEncoderDecoderWithOffset, label2id: Dict[str, int], mode: str, ): @@ -123,12 +380,209 @@ def decode( ) return result + def parse( + self, + encoding: List[int], + decoded_annotations: List[LabeledSpan], + text_length: int, + ) -> Tuple[LabeledSpan, List[int]]: + if not self.span_encoder_decoder.allow_nested: + # if we have a generated a beginning of a previous span, we need to generate the exact ending next, + # thus we set the follow-up candidates to the ending (label or end index) of the previous span + previous_encodings = [self.encode(ann) for ann in decoded_annotations] + previous_to_follow_up = {tuple(enc[:2]): enc[2] for enc in previous_encodings} + if tuple(encoding) in previous_to_follow_up: + raise IncompleteEncodingException( + "the encoding has not enough values to decode as LabeledSpan", + encoding=encoding, + follow_up_candidates=[previous_to_follow_up[tuple(encoding)]], + ) + + if self.mode == "label_indices": + label, remaining = _parse_label( + encoding, id2label=self.id2label, annotation_type=LabeledSpan + ) + elif self.mode == "indices_label": + label, remaining = None, encoding + else: + raise ValueError(f"unknown mode: {self.mode}") + + span, remaining = self.span_encoder_decoder.parse( + encoding=remaining, decoded_annotations=decoded_annotations, text_length=text_length + ) + if label is None: + label, remaining = _parse_label( + remaining, id2label=self.id2label, annotation_type=LabeledSpan + ) -class BinaryRelationEncoderDecoder(AnnotationEncoderDecoder[BinaryRelation, List[int]]): + result = LabeledSpan(start=span.start, end=span.end, label=label) + if not self.span_encoder_decoder.allow_nested: + # if we have parsed a span that has the same start and end as a previous span, + # it is not allowed to have a different label + previous_spans_to_label = { + Span(start=ann.start, end=ann.end): ann.label for ann in decoded_annotations + } + if span in previous_spans_to_label and previous_spans_to_label[span] != label: + raise DecodingSpanOverlapException( + f"the encoded span {result} overlaps with another span with a different label: " + f"{previous_spans_to_label[span]}", + encoding=encoding, + remaining=remaining, + ) + + return result, remaining + + +class LabeledMultiSpanEncoderDecoder( + GenerativeAnnotationEncoderDecoder[LabeledMultiSpan, List[int]] +): + """An encoder-decoder for LabeledMultiSpans. + + To encode a LabeledMultiSpan, the slices (start-end-index-tuples) are encoded in order, + followed by the label id. Note that we expect the MultiSpan to have at least one slice. + """ + + def __init__( + self, + span_encoder_decoder: SpanEncoderDecoderWithOffset, + label2id: Dict[str, int], + ): + self.span_encoder_decoder = span_encoder_decoder + self.label2id = label2id + self.id2label = {idx: label for label, idx in self.label2id.items()} + + def encode( + self, annotation: LabeledMultiSpan, metadata: Optional[Dict[str, Any]] = None + ) -> List[int]: + if len(annotation.slices) == 0: + raise EncodingEmptySlicesException( + "LabeledMultiSpan must have at least one slice to encode it.", + annotation=annotation, + ) + encoding = [] + for start, end in annotation.slices: + encoded_span = self.span_encoder_decoder.encode( + annotation=Span(start=start, end=end), metadata=metadata + ) + encoding.extend(encoded_span) + encoding.append(self.label2id[annotation.label]) + return encoding + + def decode( + self, encoding: List[int], metadata: Optional[Dict[str, Any]] = None + ) -> LabeledMultiSpan: + if len(encoding) % 2 != 1: + raise DecodingLengthException( + f"an odd number of encoding entries is required for decoding a LabeledMultiSpan, " + f"but got {len(encoding)}", + encoding=encoding, + ) + slices = [] + for i in range(0, len(encoding) - 1, 2): + encoded_span = encoding[i : i + 2] + span = self.span_encoder_decoder.decode(encoding=encoded_span, metadata=metadata) + slices.append((span.start, span.end)) + label = self.id2label[encoding[-1]] + return LabeledMultiSpan(slices=tuple(slices), label=label) + + def parse( + self, + encoding: List[int], + decoded_annotations: List[LabeledMultiSpan], + text_length: int, + ) -> Tuple[LabeledMultiSpan, List[int]]: + decoded_spans = [] + decoded_slices_to_spans = defaultdict(list) + for ann in decoded_annotations: + for start, end in ann.slices: + span = Span(start=start, end=end) + decoded_spans.append(span) + decoded_slices_to_spans[(start, end)].append(ann) + try: + slices: List[Tuple[int, int]] = [] + remaining = encoding + while True: + try: + span, remaining = self.span_encoder_decoder.parse( + encoding=remaining, + decoded_annotations=decoded_spans, + text_length=text_length, + ) + except IncompleteEncodingException as e: + # if the current remaining encoding was empty, but we already have slices, + # we need to add the label ids to the follow-up candidates + if len(remaining) == 0 and len(slices) > 0: + raise IncompleteEncodingException( + "the encoding has not enough values to decode as LabeledMultiSpan", + encoding=encoding, + follow_up_candidates=sorted( + e.follow_up_candidates + list(self.id2label) + ), + ) + # otherwise (partial span or empty encoding), we just re-raise the exception + else: + raise e + slices.append((span.start, span.end)) + decoded_spans.append(span) + if len(remaining) > 0 and remaining[0] in self.id2label: + label = self.id2label[remaining[0]] + remaining = remaining[1:] + break + + except IncompleteEncodingException as e: + if not self.span_encoder_decoder.allow_nested and len(encoding) > 0: + # if we have a generated a beginning of a previous span, we need to generate the exact ending next, + # thus we set the follow-up candidates to the continuation of the previous span + previous_encodings = [self.encode(ann) for ann in decoded_annotations] + previous_to_follow_up = { + tuple(enc[: len(encoding)]): enc[len(encoding)] + for enc in previous_encodings + if len(enc) > len(encoding) + } + if tuple(encoding) in previous_to_follow_up: + raise IncompleteEncodingException( + "the encoding has not enough values to decode as LabeledMultiSpan", + encoding=encoding, + follow_up_candidates=[previous_to_follow_up[tuple(encoding)]], + ) + raise e + + result = LabeledMultiSpan(slices=tuple(slices), label=label) + + # check for any overlap with previously decoded spans + for s in slices: + if s in decoded_slices_to_spans: + # get all previous LabeledMultiSpans that have any of the slices in common + previous_spans = decoded_slices_to_spans[s] + for previous_span in previous_spans: + if previous_span != result: + if previous_span.slices != result.slices: + raise DecodingSpanOverlapException( + "the decoded slices partly overlap with the slices of another LabeledMultiSpan", + encoding=encoding, + remaining=remaining, + ) + elif ( + not self.span_encoder_decoder.allow_nested + and previous_span.label != result.label + ): + raise DecodingSpanNestedException( + "the decoded LabeledMultiSpan is nested in another LabeledMultiSpan " + "with a different label", + encoding=encoding, + remaining=remaining, + ) + + return result, remaining + + +class BinaryRelationEncoderDecoder( + GenerativeAnnotationEncoderDecoderWithParseWithErrors[BinaryRelation] +): def __init__( self, - head_encoder_decoder: AnnotationEncoderDecoder[Span, List[int]], - tail_encoder_decoder: AnnotationEncoderDecoder[Span, List[int]], + head_encoder_decoder: GenerativeAnnotationEncoderDecoder[Annotation, List[int]], + tail_encoder_decoder: GenerativeAnnotationEncoderDecoder[Annotation, List[int]], label2id: Dict[str, int], mode: str, loop_dummy_relation_name: Optional[str] = None, @@ -241,3 +695,95 @@ def decode( rel = BinaryRelation(head=head, tail=tail, label=label) return rel + + def parse( + self, + encoding: List[int], + decoded_annotations: List[BinaryRelation], + text_length: int, + ) -> Tuple[BinaryRelation, List[int]]: + if self.mode.endswith("_label"): + label, remaining = None, encoding + argument_mode = self.mode[: -len("_label")] + elif self.mode.startswith("label_"): + label, remaining = _parse_label( + encoding, id2label=self.id2label, annotation_type=BinaryRelation + ) + argument_mode = self.mode[len("label_") :] + else: + raise ValueError(f"unknown mode: {self.mode}") + if argument_mode == "head_tail": + first_argument_encoder = self.head_encoder_decoder + second_argument_encoder = self.tail_encoder_decoder + elif argument_mode == "tail_head": + first_argument_encoder = self.tail_encoder_decoder + second_argument_encoder = self.head_encoder_decoder + else: + raise ValueError(f"unknown argument mode: {argument_mode}") + + decoded_arguments = [] + for rel in decoded_annotations: + decoded_arguments.append(rel.head) + decoded_arguments.append(rel.tail) + + first_argument, remaining = first_argument_encoder.parse( + encoding=remaining, decoded_annotations=decoded_arguments, text_length=text_length + ) + decoded_arguments.append(first_argument) + found_none = False + try: + second_argument, remaining = second_argument_encoder.parse( + encoding=remaining, decoded_annotations=decoded_arguments, text_length=text_length + ) + except DecodingException as e: + if self.none_label is not None: + none_id = self.label2id[self.none_label] + if remaining[0:3] == [none_id] * 3: + second_argument = first_argument + remaining = remaining[3:] + found_none = True + elif len(remaining) == 0 and isinstance(e, IncompleteEncodingException): + raise IncompleteEncodingException( + "the encoding has not enough values to decode as BinaryRelation", + encoding=encoding, + follow_up_candidates=sorted(e.follow_up_candidates + [none_id]), + ) + elif 0 < len(remaining) < 3 and remaining == [none_id] * len(remaining): + raise IncompleteEncodingException( + "the encoding has not enough values to decode as BinaryRelation", + encoding=encoding, + follow_up_candidates=[none_id], + ) + else: + raise e + else: + raise e + + if label is None: + if found_none: + id2label = { + id: label for label, id in self.label2id.items() if label == self.none_label + } + else: + id2label = { + id: label for label, id in self.label2id.items() if label != self.none_label + } + label, remaining = _parse_label( + remaining, id2label=id2label, annotation_type=BinaryRelation + ) + + if label == self.none_label: + if self.loop_dummy_relation_name is None: + raise ValueError( + f"loop_dummy_relation_name is not set, but none_label={self.none_label} " + f"was found in the encoding: {encoding} (label2id: {self.label2id}))" + ) + label = self.loop_dummy_relation_name + + if argument_mode == "head_tail": + rel = BinaryRelation(head=first_argument, tail=second_argument, label=label) + elif argument_mode == "tail_head": + rel = BinaryRelation(head=second_argument, tail=first_argument, label=label) + else: + raise ValueError(f"unknown argument mode: {argument_mode}") + return rel, remaining diff --git a/src/pie_modules/taskmodules/pointer_network/logits_processor.py b/src/pie_modules/taskmodules/pointer_network/logits_processor.py index 04ae0e7d3..8024b5cf5 100644 --- a/src/pie_modules/taskmodules/pointer_network/logits_processor.py +++ b/src/pie_modules/taskmodules/pointer_network/logits_processor.py @@ -41,4 +41,7 @@ def __call__( self._prefix_allowed_tokens_fn(batch_id, sent, mask.size(1)), ] = 0 - return scores + mask + # It may happen that all valid candidates have a score of -inf. Since we still want + # only valid candidates, we replace all -inf scores with a very small number. + scores_finite = torch.nan_to_num(scores) + return scores_finite + mask diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index 2042f0ec4..bf67c0519 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -1,11 +1,14 @@ import dataclasses import json import logging +import time from collections import Counter, defaultdict from functools import cmp_to_key from typing import ( Any, Dict, + Generic, + Hashable, Iterable, Iterator, List, @@ -14,6 +17,7 @@ Set, Tuple, Type, + TypeVar, Union, ) @@ -34,9 +38,11 @@ # import for backwards compatibility (don't remove!) from pie_modules.documents import ( + TextDocumentWithLabeledMultiSpans, TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, ) +from ..annotations import LabeledMultiSpan from ..document.processing import token_based_document_to_text_based, tokenize_document from ..utils import resolve_type from .common import BatchableMixin, DecodingException, get_first_occurrence_index @@ -46,6 +52,8 @@ ) from .pointer_network.annotation_encoder_decoder import ( BinaryRelationEncoderDecoder, + IncompleteEncodingException, + LabeledMultiSpanEncoderDecoder, LabeledSpanEncoderDecoder, SpanEncoderDecoderWithOffset, ) @@ -83,15 +91,74 @@ def decoder_attention_mask(self) -> List[int]: ] TaskOutputType: TypeAlias = LabelsAndOptionalConstraints -KEY_INVALID_CORRECT = "correct" +def span_sort_key(span: Annotation) -> Tuple[int, ...]: + # TODO: use the full span to sort + # just use the (first) start index to sort + if isinstance(span, LabeledSpan): + return (span.start,) + elif isinstance(span, LabeledMultiSpan): + if len(span.slices) == 0: + raise Exception(f"can not sort LabeledMultiSpan with empty slices: {span}") + return (span.slices[0][0],) + else: + raise Exception(f"unexpected type: {type(span)}") -def cmp_src_rel(v1: BinaryRelation, v2: BinaryRelation) -> int: - if not all(isinstance(ann, LabeledSpan) for ann in [v1.head, v1.tail, v2.head, v2.tail]): - raise Exception(f"expected LabeledSpan, but got: {v1}, {v2}") - if v1.head.start == v2.head.start: # v1[0]["from"] == v2[0]["from"]: - return v1.tail.start - v2.tail.start # v1[1]["from"] - v2[1]["from"] - return v1.head.start - v2.head.start # v1[0]["from"] - v2[0]["from"] + +def binary_relation_sort_key(rel: BinaryRelation) -> Tuple[int, ...]: + # use the start indices of head and tail to sort + return span_sort_key(rel.head) + span_sort_key(rel.tail) + + +def annotation_to_indices(argument_annotation: Annotation) -> Tuple[int, ...]: + if isinstance(argument_annotation, LabeledSpan): + return argument_annotation.start, argument_annotation.end + elif isinstance(argument_annotation, LabeledMultiSpan): + result: List[int] = [] + for s in argument_annotation.slices: + result.extend(s) + return tuple(result) + else: + raise Exception(f"unexpected type: {type(argument_annotation)}") + + +def annotation_to_label(annotation: Annotation) -> str: + if isinstance(annotation, (LabeledSpan, LabeledMultiSpan)): + return annotation.label + else: + raise Exception(f"unexpected type: {type(annotation)}") + + +V = TypeVar("V") + + +class SimpleCache(Generic[V]): + def __init__(self, max_size: int): + self.cache: Dict[Hashable, V] = dict() + self.time_added: Dict[Hashable, int] = dict() + self.max_size = max_size + + def prune(self) -> None: + if len(self.cache) > self.max_size: + # remove the least recently added item + oldest_key = min(self.time_added, key=self.time_added.get) # type: ignore + del self.cache[oldest_key] + del self.time_added[oldest_key] + + def add(self, key: Hashable, value: V) -> None: + self.cache[key] = value + # update last access with the current time in ms + self.time_added[key] = time.time_ns() + self.prune() + + def get(self, key: Hashable) -> V: + return self.cache[key] + + def __len__(self) -> int: + return len(self.cache) + + def __contains__(self, key: Hashable) -> bool: + return key in self.cache @TaskModule.register() @@ -117,6 +184,7 @@ def __init__( none_label: str = "none", loop_dummy_relation_name: str = "loop", constrained_generation: bool = False, + constrain_with_previous_records: bool = True, # generic pointer network label_tokens: Optional[Dict[str, str]] = None, label_representations: Optional[Dict[str, str]] = None, @@ -170,8 +238,9 @@ def __init__( self.none_label = none_label self.loop_dummy_relation_name = loop_dummy_relation_name self.constrained_generation = constrained_generation + self.constrain_with_previous_records = constrain_with_previous_records # will be set in _post_prepare() - self.relation_encoder_decoder: BinaryRelationEncoderDecoder + self.annotation_encoder_decoder: BinaryRelationEncoderDecoder # collected in prepare(), if not passed in self.labels_per_layer = labels_per_layer @@ -203,6 +272,12 @@ def __init__( # logging self.log_first_n_examples = log_first_n_examples + # cache + self.cache_decoded: SimpleCache[Tuple[List[BinaryRelation], List[int]]] = SimpleCache( + # TODO: set max_size to a reasonable value when using the cache is fixed + max_size=0 + ) + @property def document_type(self) -> Type[TextBasedDocument]: return self._document_type @@ -253,15 +328,25 @@ def _prefix_allowed_tokens_fn_with_maximum( )[0].labels else: unpadded_label_ids = [] - _, _, remaining = self.decode_relations(label_ids=unpadded_label_ids) - # this is a binary mask - constraint = self._build_constraint( - previous_ids=remaining, input_len=maximum - self.pointer_offset - ) - # convert to indices - allowed_indices = torch.nonzero(constraint).squeeze(1) - # convert to a list - return allowed_indices.tolist() + + try: + follow_up_candidates = self.get_follow_up_candidates( + previous_ids=unpadded_label_ids, input_len=maximum - self.pointer_offset + ) + except DecodingException as e: + # if the decoding failed, allow all tokens. Maybe the model can recover from this state + # TODO: remove the warning? + logger.warning(f"failed to get follow_up_candidates: {e}, allow all tokens") + return list(range(maximum)) + + # If there is only one candidate, we add the eos token. This is because two ids are sampled + # when using GenerationMixin.beam_search() and we want to avoid that a non-candidate which + # is "more wrong" is sampled. + if len(follow_up_candidates) == 1: + follow_up_candidates.add(self.eos_id) + + # sort and convert to a list + return sorted(follow_up_candidates) def _prepare(self, documents: Sequence[DocumentType]) -> None: # collect all labels @@ -323,17 +408,29 @@ def _post_prepare(self) -> None: offset=self.pointer_offset, exclusive_end=False ) span_labels = self.labels_per_layer[self.span_layer_name] - labeled_span_encoder_decoder = LabeledSpanEncoderDecoder( - span_encoder_decoder=span_encoder_decoder, - # restrict label2id to get better error messages - label2id={label: idx for label, idx in self.label2id.items() if label in span_labels}, - mode="indices_label", - ) + # restrict label2id to get better error messages + span_label2id = { + label: idx for label, idx in self.label2id.items() if label in span_labels + } + labeled_span_encoder_decoder: Union[ + LabeledSpanEncoderDecoder, LabeledMultiSpanEncoderDecoder + ] + if self.use_multi_spans: + labeled_span_encoder_decoder = LabeledMultiSpanEncoderDecoder( + span_encoder_decoder=span_encoder_decoder, + label2id=span_label2id, + ) + else: + labeled_span_encoder_decoder = LabeledSpanEncoderDecoder( + span_encoder_decoder=span_encoder_decoder, + label2id=span_label2id, + mode="indices_label", + ) relation_labels = self.labels_per_layer[self.relation_layer_name] + [ self.loop_dummy_relation_name, self.none_label, ] - self.relation_encoder_decoder = BinaryRelationEncoderDecoder( + self.annotation_encoder_decoder = BinaryRelationEncoderDecoder( head_encoder_decoder=labeled_span_encoder_decoder, tail_encoder_decoder=labeled_span_encoder_decoder, # restrict label2id to get better error messages @@ -402,6 +499,10 @@ def pointer_offset(self) -> int: def target_ids(self) -> Set[int]: return set(range(self.pointer_offset)) + @property + def use_multi_spans(self) -> bool: + return issubclass(self.document_type, TextDocumentWithLabeledMultiSpans) + def configure_model_metric(self, stage: Optional[str] = None) -> Optional[Metric]: layer_metrics = { layer_name: PrecisionRecallAndF1ForLabeledAnnotations() @@ -412,40 +513,12 @@ def configure_model_metric(self, stage: Optional[str] = None) -> Optional[Metric unbatch_function=self.unbatch_output, decode_layers_with_errors_function=self.decode_annotations, layer_metrics=layer_metrics, - error_key_correct=KEY_INVALID_CORRECT, + error_key_correct=self.annotation_encoder_decoder.KEY_INVALID_CORRECT, ) - def decode_relations( - self, - label_ids: List[int], - ) -> Tuple[List[BinaryRelation], Dict[str, int], List[int]]: - errors: Dict[str, int] = defaultdict(int) - encodings = [] - current_encoding: List[int] = [] - valid_encoding: BinaryRelation - if len(label_ids): - for i in label_ids: - current_encoding.append(i) - # An encoding is complete when it ends with a relation_id - # or when it contains a none_id and has a length of 7 - if i in self.relation_ids or (i == self.none_id and len(current_encoding) == 7): - # try to decode the current relation encoding - try: - valid_encoding = self.relation_encoder_decoder.decode( - encoding=current_encoding - ) - encodings.append(valid_encoding) - errors[KEY_INVALID_CORRECT] += 1 - except DecodingException as e: - errors[e.identifier] += 1 - - current_encoding = [] - - return encodings, dict(errors), current_encoding - - def encode_annotations( - self, layers: Dict[str, List[Annotation]], metadata: Optional[Dict[str, Any]] = None - ) -> TaskOutputType: + def prepare_annotations_for_encoding( + self, layers: Dict[str, List[Annotation]] + ) -> List[BinaryRelation]: if not set(layers.keys()) == set(self.layer_names): raise Exception(f"unexpected layers: {layers.keys()}. expected: {self.layer_names}") @@ -454,17 +527,15 @@ def encode_annotations( # encode relations all_relation_arguments = set() - relation_encodings = dict() + prepared_relations = [] for rel in layers[self.relation_layer_name]: if not isinstance(rel, BinaryRelation): raise Exception(f"expected BinaryRelation, but got: {rel}") if rel.label in self.labels_per_layer[self.relation_layer_name]: - encoded_relation = self.relation_encoder_decoder.encode( - annotation=rel, metadata=metadata - ) + encoded_relation = self.annotation_encoder_decoder.encode(annotation=rel) if encoded_relation is None: raise Exception(f"failed to encode relation: {rel}") - relation_encodings[rel] = encoded_relation + prepared_relations.append(rel) all_relation_arguments.update([rel.head, rel.tail]) # encode spans that are not arguments of any relation @@ -475,29 +546,40 @@ def encode_annotations( dummy_relation = BinaryRelation( head=span, tail=span, label=self.loop_dummy_relation_name ) - encoded_relation = self.relation_encoder_decoder.encode( - annotation=dummy_relation, metadata=metadata - ) + encoded_relation = self.annotation_encoder_decoder.encode(annotation=dummy_relation) if encoded_relation is not None: - relation_encodings[dummy_relation] = encoded_relation + prepared_relations.append(dummy_relation) - # sort relations by start indices of head and tail # TODO: is this correct? - sorted_relations = sorted(relation_encodings, key=cmp_to_key(cmp_src_rel)) + # sort relations by start indices of head and tail + sorted_relations = sorted(prepared_relations, key=binary_relation_sort_key) + return sorted_relations + + def encode_annotations( + self, layers: Dict[str, List[Annotation]], metadata: Optional[Dict[str, Any]] = None + ) -> TaskOutputType: + prepared_annotations = self.prepare_annotations_for_encoding(layers=layers) # build target_ids target_ids = [] - for rel in sorted_relations: - encoded_relation = relation_encodings[rel] - target_ids.extend(encoded_relation) + for rel in prepared_annotations: + encoded_annotation = self.annotation_encoder_decoder.encode(annotation=rel) + target_ids.extend(encoded_annotation) target_ids.append(self.eos_id) + if self.create_constraints: + if metadata is None or "src_len" not in metadata: + raise Exception("metadata with 'src_len' is required to create constraints") + constraints = self.build_constraints( + input_len=metadata["src_len"], target_ids=target_ids + ).tolist() + else: + constraints = None + + result = LabelsAndOptionalConstraints(labels=target_ids, constraints=constraints) + # sanity check - _, encoding_errors, remaining = self.decode_relations(label_ids=target_ids) - if ( - not all(v == 0 for k, v in encoding_errors.items() if k != "correct") - or len(remaining) > 0 - ): - decoded, invalid = self.decode_annotations(LabelsAndOptionalConstraints(target_ids)) + decoded, decoding_errors = self.decode_annotations(encoding=result) + if not all(v == 0 for k, v in decoding_errors.items() if k != "correct"): not_encoded = {} for layer_name in layers: # convert to dicts to make them comparable (original annotations are attached which breaks comparison) @@ -510,50 +592,49 @@ def encode_annotations( not_encoded[layer_name] = list(filtered) if len(not_encoded) > 0: logger.warning( - f"encoding errors: {encoding_errors}, skipped annotations:\n" + f"encoding errors: {decoding_errors}, skipped annotations:\n" f"{json.dumps(not_encoded, sort_keys=True, indent=2)}" ) - elif len([tag for tag in remaining if tag != self.eos_id]) > 0: - logger.warning( - f"encoding errors: {encoding_errors}, remaining encoding ids: {remaining}" - ) - if self.create_constraints: - if metadata is None or "src_len" not in metadata: - raise Exception("metadata with 'src_len' is required to create constraints") - constraints = self.build_constraints( - input_len=metadata["src_len"], target_ids=target_ids - ).tolist() - else: - constraints = None - return LabelsAndOptionalConstraints(labels=target_ids, constraints=constraints) + return result - def decode_annotations( - self, encoding: TaskOutputType - ) -> Tuple[Dict[str, Iterable[Annotation]], Dict[str, int]]: - decoded_relations, errors, remaining = self.decode_relations(label_ids=encoding.labels) - relation_tuples: List[Tuple[Tuple[int, int], Tuple[int, int], str]] = [] - entity_labels: Dict[Tuple[int, int], List[str]] = defaultdict(list) - for rel in decoded_relations: - head_span = (rel.head.start, rel.head.end) - entity_labels[head_span].append(rel.head.label) + def postprocess_decoded_annotations( + self, decoded_annotations: List[Annotation] + ) -> Dict[str, Iterable[Annotation]]: + relation_tuples: List[Tuple[Tuple[int, ...], Tuple[int, ...], str]] = [] + entity_labels: Dict[Tuple[int, ...], List[str]] = defaultdict(list) + for rel in decoded_annotations: + if not isinstance(rel, BinaryRelation): + raise Exception(f"expected BinaryRelation, but got: {rel}") + head_indices = annotation_to_indices(rel.head) + head_label = annotation_to_label(rel.head) + entity_labels[head_indices].append(head_label) if rel.label != self.loop_dummy_relation_name: - tail_span = (rel.tail.start, rel.tail.end) - entity_labels[tail_span].append(rel.tail.label) - relation_tuples.append((head_span, tail_span, rel.label)) + tail_indices = annotation_to_indices(rel.tail) + tail_label = annotation_to_label(rel.tail) + entity_labels[tail_indices].append(tail_label) + relation_tuples.append((head_indices, tail_indices, rel.label)) else: assert rel.head == rel.tail # It may happen that some spans take part in multiple relations, but got generated with different labels. # In this case, we just create one span and take the most common label. - entities: Dict[Tuple[int, int], LabeledSpan] = {} - for (start, end), labels in entity_labels.items(): + entities: Dict[Tuple[int, ...], Union[LabeledSpan, LabeledMultiSpan]] = {} + for span_indices, labels in entity_labels.items(): c = Counter(labels) # if len(c) > 1: # logger.warning(f"multiple labels for span, take the most common: {dict(c)}") most_common_label = c.most_common(1)[0][0] - entities[(start, end)] = LabeledSpan(start=start, end=end, label=most_common_label) + if self.use_multi_spans: + slices = tuple( + (span_indices[i], span_indices[i + 1]) for i in range(0, len(span_indices), 2) + ) + entities[span_indices] = LabeledMultiSpan(slices=slices, label=most_common_label) + else: + entities[span_indices] = LabeledSpan( + start=span_indices[0], end=span_indices[1], label=most_common_label + ) entity_layer = list(entities.values()) relation_layer = [ @@ -563,75 +644,75 @@ def decode_annotations( return { self.span_layer_name: entity_layer, self.relation_layer_name: relation_layer, - }, errors + } - def _build_constraint( - self, - previous_ids: List[int], - input_len: int, + def decode_annotations( + self, encoding: TaskOutputType + ) -> Tuple[Dict[str, Iterable[Annotation]], Dict[str, int]]: + try: + ( + decoded_annotations, + errors, + remaining, + ) = self.annotation_encoder_decoder.parse_with_error_handling( + encoding=encoding.labels, + input_length=self.tokenizer.model_max_length, + stop_ids=[self.eos_id], + disrespect_decoded_annotations=not self.constrain_with_previous_records, + ) + return self.postprocess_decoded_annotations(decoded_annotations), errors + except Exception as e: + logger.error(f"failed to decode annotations: {e}") + return {layer_name: [] for layer_name in self.layer_names}, {"full_encoding": 1} + + def follow_up_candidates_to_mask( + self, follow_up_candidates: Set[int], input_len: int ) -> torch.LongTensor: - result: torch.LongTensor = torch.zeros(input_len + self.pointer_offset, dtype=torch.int64) + result = torch.zeros(input_len + self.pointer_offset).to(torch.long) + result[list(follow_up_candidates)] = 1 + return result + + def get_follow_up_candidates(self, previous_ids: List[int], input_len: int) -> Set[int]: + # if the eos was already generated, do not allow any other token if self.eos_id in previous_ids: - # once eos is predicted, only allow padding - result[self.target_pad_id] = 1 - return result - contains_none = self.none_id in previous_ids - idx = len(previous_ids) - if idx == 0: # [] -> first span start or eos - # Allow all offsets ... - result[self.pointer_offset :] = 1 - # ... and the eos token. - result[self.eos_id] = 1 - elif idx == 1: # [14] -> first span end - # Allow all offsets greater than the span start. - span_start = previous_ids[-1] - result[span_start:] = 1 - elif idx == 2: # [14,14] -> first span label - # Allow only span ids. - result[self.span_ids] = 1 - elif idx == 3: # [14,14,s1] -> second span start or none - # Allow all offsets ... - result[self.pointer_offset :] = 1 - # ... and the none token (for single spans). - result[self.none_id] = 1 - # But exclude offsets covered by the first span. - first_span_start = previous_ids[0] - first_span_end = previous_ids[1] + 1 - result[first_span_start:first_span_end] = 0 - elif idx == 4: # [14,14,s1,23] -> second span end or none - # if we have a none label, allow only none - if contains_none: - result[self.none_id] = 1 - else: - # Allow all offsets after the second span start ... - second_span_start = previous_ids[-1] - result[second_span_start:] = 1 - # ... but exclude offsets covered by the first span. - first_span_start = previous_ids[0] - first_span_end = previous_ids[1] + 1 - result[first_span_start:first_span_end] = 0 - # Mitigate overlap of first and second span: - # if first span is after the second span, - # disallow all offsets after the first span end - if first_span_start > second_span_start: - result[first_span_end:] = 0 - elif idx == 5: # [14,14,s1,23,25] -> second span label or none - # if we have a none label, allow only none - if contains_none: - result[self.none_id] = 1 - else: - # allow only span ids - result[self.span_ids] = 1 - elif idx == 6: # [14,14,s1,23,25,s2] -> relation label or none - # if we have a none label, allow only none - if contains_none: - result[self.none_id] = 1 - else: - # allow only relation ids - result[self.relation_ids] = 1 + return {self.eos_id} + + # speed up by using a cache + cache_key = tuple(previous_ids[:-1]) + if cache_key in self.cache_decoded: + decoded_annotations, previous_successfully_decoded = self.cache_decoded.get(cache_key) + encoding = previous_ids[len(previous_successfully_decoded) :] else: - # any longer sequence can only be completed with padding - result[self.target_pad_id] = 1 + decoded_annotations = None + encoding = previous_ids + ( + decoded_annotations, + decoding_errors, + remaining, + ) = self.annotation_encoder_decoder.parse_with_error_handling( + encoding=encoding, + input_length=input_len, + stop_ids=[self.eos_id], + decoded_annotations=decoded_annotations, + disrespect_decoded_annotations=not self.constrain_with_previous_records, + ) + successfully_decoded = previous_ids[: len(previous_ids) - len(remaining)] + self.cache_decoded.add(tuple(previous_ids), (decoded_annotations, successfully_decoded)) + try: + self.annotation_encoder_decoder.parse( + encoding=remaining, decoded_annotations=decoded_annotations, text_length=input_len + ) + raise Exception("expected IncompleteEncodingException") + except IncompleteEncodingException as e: + result = set(e.follow_up_candidates) + + # if the encoding could be parsed completely, also allow the eos token + if len(remaining) == 0: + result.add(self.eos_id) + + if len(result) == 0: + raise Exception(f"no follow_up_candidates found: {previous_ids}") + return result def build_constraints( @@ -639,50 +720,25 @@ def build_constraints( input_len: int, target_ids: List[int], ) -> torch.LongTensor: - if not ( - isinstance(self.relation_encoder_decoder, BinaryRelationEncoderDecoder) - and self.relation_encoder_decoder.mode == "tail_head_label" - and isinstance( - self.relation_encoder_decoder.head_encoder_decoder, LabeledSpanEncoderDecoder - ) - and self.relation_encoder_decoder.head_encoder_decoder.mode == "indices_label" - and isinstance( - self.relation_encoder_decoder.head_encoder_decoder.span_encoder_decoder, - SpanEncoderDecoderWithOffset, - ) - and self.relation_encoder_decoder.head_encoder_decoder.span_encoder_decoder.offset - == self.pointer_offset - and not self.relation_encoder_decoder.head_encoder_decoder.span_encoder_decoder.exclusive_end - and self.relation_encoder_decoder.head_encoder_decoder - == self.relation_encoder_decoder.tail_encoder_decoder - ): - raise Exception( - "build_constraints() is only supported for BinaryRelationEncoderDecoder with mode 'tail_head_label' and LabeledSpanEncoderDecoder as (head|tail)_encoder_decoder with mode 'indices_label'" - ) if target_ids[-1] != self.eos_id: raise Exception( f"expected eos_id [{self.eos_id}] at the end of target_ids: {target_ids}" ) labels_without_eos = target_ids[:-1] - if len(labels_without_eos) % 7 != 0: - raise Exception( - f"expected the number of labels_without_eos to be a multiple of 7: {target_ids}" - ) constraints: List[torch.LongTensor] = [] for idx, t in enumerate(labels_without_eos): - current_tuple_start = (idx // 7) * 7 - current_tuple = target_ids[current_tuple_start:idx] - current_constraints = self._build_constraint( - previous_ids=current_tuple, input_len=input_len + follow_up_candidates = self.get_follow_up_candidates( + previous_ids=labels_without_eos[:idx], input_len=input_len + ) + current_constraints = self.follow_up_candidates_to_mask( + follow_up_candidates=follow_up_candidates, input_len=input_len ) if current_constraints[t] == 0: raise Exception( f"current_constraints[{t}] is 0, but should be 1: {current_constraints}" ) constraints.append(current_constraints) - eos_constraint: torch.LongTensor = torch.zeros( - input_len + self.pointer_offset, dtype=torch.int64 - ) + eos_constraint = torch.zeros(input_len + self.pointer_offset).to(torch.long) eos_constraint[self.eos_id] = 1 constraints.append(eos_constraint) result: torch.LongTensor = torch.stack(constraints) diff --git a/tests/taskmodules/pointer_network/test_annotation_encoder_decoder.py b/tests/taskmodules/pointer_network/test_annotation_encoder_decoder.py index 4e7eda917..f85d043a6 100644 --- a/tests/taskmodules/pointer_network/test_annotation_encoder_decoder.py +++ b/tests/taskmodules/pointer_network/test_annotation_encoder_decoder.py @@ -1,15 +1,25 @@ import pytest from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span +from pie_modules.annotations import LabeledMultiSpan from pie_modules.taskmodules.pointer_network.annotation_encoder_decoder import ( BinaryRelationEncoderDecoder, + DecodingEmptySpanException, DecodingLabelException, DecodingLengthException, DecodingNegativeIndexException, DecodingOrderException, + DecodingSpanNestedException, + DecodingSpanOverlapException, + EncodingEmptySlicesException, + EncodingEmptySpanException, + IncompleteEncodingException, + LabeledMultiSpanEncoderDecoder, LabeledSpanEncoderDecoder, SpanEncoderDecoder, SpanEncoderDecoderWithOffset, + spans_are_nested, + spans_have_overlap, ) @@ -26,6 +36,16 @@ def test_span_encoder_decoder(exclusive_end): assert encoder_decoder.decode([1, 1]) == Span(start=1, end=2) +def test_span_encoder_decoder_empty_span(): + encoder_decoder = SpanEncoderDecoder() + with pytest.raises(EncodingEmptySpanException) as excinfo: + encoder_decoder.encode(Span(start=1, end=1)) + assert ( + str(excinfo.value) + == "can not encode empty Span annotations, i.e. where the start index equals the end index" + ) + + def test_span_encoder_decoder_wrong_length(): """Test the SimpleSpanEncoderDecoder class.""" @@ -73,7 +93,240 @@ def test_span_encoder_decoder_wrong_offset(): with pytest.raises(DecodingNegativeIndexException) as excinfo: encoder_decoder.decode([-1, 2]) assert str(excinfo.value) == "indices must be positive, but got: [-1, 2]" - assert excinfo.value.identifier == "index" + assert excinfo.value.identifier == "negative_index" + + +@pytest.mark.parametrize("exclusive_end", [True, False]) +def test_span_encoder_decoder_parse(exclusive_end): + encoder_decoder = SpanEncoderDecoder(exclusive_end) + if exclusive_end: + assert encoder_decoder.parse([1, 2, 3, 4], [], 5) == (Span(start=1, end=2), [3, 4]) + else: + assert encoder_decoder.parse([1, 1, 3, 4], [], 5) == (Span(start=1, end=2), [3, 4]) + + +@pytest.mark.parametrize("exclusive_end", [True, False]) +def test_span_encoder_decoder_parse_empty_span(exclusive_end): + encoder_decoder = SpanEncoderDecoder(exclusive_end) + encoding = [3, 3, 3, 4] if exclusive_end else [3, 2, 3, 4] + with pytest.raises(DecodingEmptySpanException) as excinfo: + encoder_decoder.parse(encoding, [], 5) + assert excinfo.value.identifier == "empty_span" + assert ( + str(excinfo.value) + == "end index can not be equal to start index to decode as Span, but got: start=3, end=3" + ) + assert excinfo.value.remaining == [3, 4] + + +@pytest.mark.parametrize("exclusive_end", [True, False]) +def test_span_encoder_decoder_parse_wrong_order(exclusive_end): + encoder_decoder = SpanEncoderDecoder(exclusive_end) + encoding = [3, 2, 3, 4] if exclusive_end else [3, 1, 3, 4] + with pytest.raises(DecodingOrderException) as excinfo: + encoder_decoder.parse(encoding, [], 5) + assert excinfo.value.identifier == "order" + assert ( + str(excinfo.value) + == "end index can not be smaller than start index, but got: start=3, end=2" + ) + assert excinfo.value.remaining == [3, 4] + + +@pytest.mark.parametrize("exclusive_end", [True, False]) +def test_span_encoder_decoder_parse_negative_index(exclusive_end): + encoder_decoder = SpanEncoderDecoder(exclusive_end) + encoding = [-1, 2, 3, 4] if exclusive_end else [-1, 1, 3, 4] + with pytest.raises(DecodingNegativeIndexException) as excinfo: + encoder_decoder.parse(encoding, [], 5) + assert excinfo.value.identifier == "negative_index" + assert str(excinfo.value) == "indices must be positive, but got: start=-1, end=2" + assert excinfo.value.remaining == [3, 4] + + +def test_spans_are_nested(): + # fully nested + assert spans_are_nested(Span(start=1, end=4), Span(start=2, end=3)) + assert spans_are_nested(Span(start=2, end=3), Span(start=1, end=4)) + # nested with same start + assert spans_are_nested(Span(start=1, end=3), Span(start=1, end=2)) + assert spans_are_nested(Span(start=1, end=2), Span(start=1, end=3)) + # nested with same end + assert spans_are_nested(Span(start=2, end=4), Span(start=3, end=4)) + assert spans_are_nested(Span(start=3, end=4), Span(start=2, end=4)) + # nested with same start and end + assert spans_are_nested(Span(start=1, end=3), Span(start=1, end=3)) + + # not nested + assert not spans_are_nested(Span(start=1, end=2), Span(start=3, end=4)) + # not nested, but touching + assert not spans_are_nested(Span(start=1, end=2), Span(start=2, end=3)) + # not nested, but overlap + assert not spans_are_nested(Span(start=1, end=3), Span(start=2, end=4)) + + +def test_spans_have_overlap(): + # overlap, no touching + assert spans_have_overlap(Span(start=1, end=3), Span(start=2, end=4)) + assert spans_have_overlap(Span(start=2, end=4), Span(start=1, end=3)) + # same start, nested -> no overlap + assert not spans_have_overlap(Span(start=1, end=2), Span(start=1, end=3)) + assert not spans_have_overlap(Span(start=1, end=3), Span(start=1, end=2)) + # same end, nested -> no overlap + assert not spans_have_overlap(Span(start=2, end=3), Span(start=1, end=3)) + assert not spans_have_overlap(Span(start=1, end=3), Span(start=2, end=3)) + # same start and end, nested -> no overlap + assert not spans_have_overlap(Span(start=1, end=3), Span(start=1, end=3)) + # no overlap, not touching + assert not spans_have_overlap(Span(start=1, end=2), Span(start=3, end=4)) + assert not spans_have_overlap(Span(start=3, end=4), Span(start=1, end=2)) + # no overlap, touching + assert not spans_have_overlap(Span(start=1, end=2), Span(start=2, end=3)) + assert not spans_have_overlap(Span(start=2, end=3), Span(start=1, end=2)) + + +@pytest.mark.parametrize("allow_nested", [True, False]) +def test_span_encoder_decoder_parse_with_previous_annotations(allow_nested): + encoder_decoder = SpanEncoderDecoder(allow_nested=allow_nested) + expected_span = Span(start=1, end=3) + remaining_encoding = [3, 4] + # encoding of the expected span + remaining encoding + encoding = encoder_decoder.encode(expected_span) + remaining_encoding + other_span = Span(start=3, end=4) + nested_span = Span(start=2, end=3) + overlapping_span = Span(start=2, end=4) + # other_span should not pose a problem in any case + assert encoder_decoder.parse(encoding, [other_span], 5) == (expected_span, remaining_encoding) + # nested_span should only be allowed if allow_nested=True + if allow_nested: + assert encoder_decoder.parse(encoding, [nested_span], 5) == ( + expected_span, + remaining_encoding, + ) + else: + with pytest.raises(DecodingSpanNestedException) as excinfo: + encoder_decoder.parse(encoding, [nested_span], 5) + assert ( + str(excinfo.value) == f"the encoded span is nested in another span: {nested_span}. " + f"You can set allow_nested=True to allow nested spans." + ) + assert excinfo.value.remaining == remaining_encoding + # overlapping_span is not allowed in any case + with pytest.raises(DecodingSpanOverlapException) as excinfo: + encoder_decoder.parse(encoding, [overlapping_span], 5) + assert str(excinfo.value) == f"the encoded span overlaps with another span: {overlapping_span}" + assert excinfo.value.remaining == remaining_encoding + + +def test_dummy(): + encoder_decoder = SpanEncoderDecoder(allow_nested=False, exclusive_end=False) + previous_span = Span(start=55, end=66) + encoding = [51, 69] + text_length = 1000 + with pytest.raises(DecodingSpanNestedException) as excinfo: + encoder_decoder.parse(encoding, [previous_span], text_length) + assert ( + str(excinfo.value) + == f"the encoded span is nested in another span: {previous_span}. You can set allow_nested=True to allow nested spans." + ) + + +@pytest.mark.parametrize( + "exclusive_end,allow_nested", [(False, False), (False, True), (True, False), (True, True)] +) +def test_span_encoder_decoder_parse_incomplete_0(exclusive_end, allow_nested): + encoder_decoder = SpanEncoderDecoder(exclusive_end=exclusive_end, allow_nested=allow_nested) + # no previous annotations + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse([], [], 6) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + assert excinfo.value.follow_up_candidates == [0, 1, 2, 3, 4, 5] + # previous annotation + other_span = Span(start=2, end=4) + encoded_other_span = encoder_decoder.encode(other_span) + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse([], [other_span], 6) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + if allow_nested: + assert excinfo.value.follow_up_candidates == [0, 1, 2, 3, 4, 5] + else: + if exclusive_end: + assert encoded_other_span == [2, 4] + else: + assert encoded_other_span == [2, 3] + # index 3 is excluded because they are covered by the other_span + # Note that index 2 is not included, despite being covered by the other_span, + # because we allow to generate the exact same span again. + assert excinfo.value.follow_up_candidates == [0, 1, 2, 4, 5] + + +@pytest.mark.parametrize( + "allow_nested,exclusive_end", [(False, False), (False, True), (True, False), (True, True)] +) +def test_span_encoder_decoder_parse_incomplete_1(allow_nested, exclusive_end): + encoder_decoder = SpanEncoderDecoder(exclusive_end=exclusive_end, allow_nested=allow_nested) + encoding = [1] + + # no previous annotations + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding, [], 6) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect an end index, so the follow-up candidates ... + if exclusive_end: + # are bigger than 1, but smaller still in the range of the text length (equal or smaller than 5) + assert excinfo.value.follow_up_candidates == [2, 3, 4, 5, 6] + else: + # bigger or equal to 1, but smaller still in the range of the text length (smaller than 5) + assert excinfo.value.follow_up_candidates == [1, 2, 3, 4, 5] + + # previous annotations + # a span before the current start index should not affect the follow-up candidates + other_span_before = Span(start=0, end=1) + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding, [other_span_before], 6) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + if exclusive_end: + assert excinfo.value.follow_up_candidates == [2, 3, 4, 5, 6] + else: + assert excinfo.value.follow_up_candidates == [1, 2, 3, 4, 5] + + # a span after the current start index should limit the follow-up candidates + other_span_after = Span(start=2, end=4) + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding, [other_span_after], 6) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + if allow_nested: + if exclusive_end: + # 3 as end index is excluded because the resulting span [1, 3) + # would have an overlap with the other_span_after + assert excinfo.value.follow_up_candidates == [2, 4, 5, 6] + else: + # 2 as end index is excluded because the resulting span [1, 3) + # would have an overlap with the other_span_after + assert excinfo.value.follow_up_candidates == [1, 3, 4, 5] + else: + if exclusive_end: + assert excinfo.value.follow_up_candidates == [2] + else: + assert excinfo.value.follow_up_candidates == [1] + + nesting_span = Span(start=0, end=2) + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding, [nesting_span], 6) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + if allow_nested: + if exclusive_end: + # only the span [1, 2) is allowed, so only 2 is a valid follow-up candidate + assert excinfo.value.follow_up_candidates == [2] + else: + # only the span [1, 2) is allowed, so only 1 is a valid follow-up candidate + assert excinfo.value.follow_up_candidates == [1] + else: + # We generate the exact same span again, so the follow-up candidates are the same as for the previous case + if exclusive_end: + assert excinfo.value.follow_up_candidates == [2] + else: + assert excinfo.value.follow_up_candidates == [1] def test_span_encoder_decoder_with_offset(): @@ -85,6 +338,31 @@ def test_span_encoder_decoder_with_offset(): assert encoder_decoder.decode([2, 3]) == Span(start=1, end=2) +def test_span_encoder_decoder_with_offset_parse(): + """Test the SpanEncoderDecoderWithOffset class.""" + encoder_decoder = SpanEncoderDecoderWithOffset(offset=1) + expected_span = Span(start=1, end=3) + encoded_span = encoder_decoder.encode(expected_span) + + # test without remaining encoding + assert encoder_decoder.parse(encoded_span, [], 6) == (expected_span, []) + + # test with remaining encoding + remaining_encoding = [3, 4] + assert encoder_decoder.parse(encoded_span + remaining_encoding, [], 6) == ( + expected_span, + remaining_encoding, + ) + + +def test_span_encoder_decoder_with_offset_parse_incomplete(): + encoder_decoder = SpanEncoderDecoderWithOffset(offset=1) + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse([2], [], 6) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + assert excinfo.value.follow_up_candidates == [3, 4, 5, 6, 7] + + @pytest.mark.parametrize("mode", ["indices_label", "label_indices"]) def test_labeled_span_encoder_decoder(mode): """Test the LabeledSpanEncoderDecoder class.""" @@ -145,6 +423,204 @@ def test_labeled_span_encoder_decoder_unknown_mode(): assert str(excinfo.value) == "unknown mode: unknown" +@pytest.mark.parametrize("mode", ["indices_label", "label_indices"]) +def test_labeled_span_encoder_decoder_parse(mode): + """Test the LabeledSpanEncoderDecoder class.""" + + label2id = {"A": 0, "B": 1} + encoder_decoder = LabeledSpanEncoderDecoder( + span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), + label2id=label2id, + mode=mode, + ) + expected_span = LabeledSpan(start=1, end=3, label="A") + encoded_span = encoder_decoder.encode(expected_span) + remaining_encoding = [3, 4] + # encoding of the expected span + remaining encoding + encoding = encoded_span + remaining_encoding + assert encoder_decoder.parse(encoding, [], 6) == (expected_span, remaining_encoding) + + +def test_labeled_span_encoder_decoder_parse_unknown_mode(): + """Test the LabeledSpanEncoderDecoder class.""" + + label2id = {"A": 0, "B": 1} + encoder_decoder = LabeledSpanEncoderDecoder( + span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), + label2id=label2id, + mode="unknown", + ) + with pytest.raises(ValueError) as excinfo: + encoder_decoder.parse([0, 3, 4], [], 6) + assert str(excinfo.value) == "unknown mode: unknown" + + +@pytest.mark.parametrize("mode", ["label_indices", "indices_label"]) +def test_labeled_span_encoder_decoder_parse_incomplete(mode): + label2id = {"A": 0, "B": 1} + encoder_decoder = LabeledSpanEncoderDecoder( + span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), + label2id=label2id, + mode=mode, + ) + if mode == "label_indices": + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse([], [], 6) + assert str(excinfo.value) == "the encoding has not enough values to decode as LabeledSpan" + assert excinfo.value.follow_up_candidates == [0, 1] + elif mode == "indices_label": + encoded_span = encoder_decoder.span_encoder_decoder.encode(Span(start=1, end=2)) + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoded_span, [], 6) + assert str(excinfo.value) == "the encoding has not enough values to decode as LabeledSpan" + assert excinfo.value.follow_up_candidates == [0, 1] + else: + raise ValueError(f"unknown mode: {mode}") + + +@pytest.mark.parametrize("mode", ["label_indices", "indices_label"]) +def test_labeled_span_encoder_decoder_parse_with_previous_exact_overlap(mode): + label2id = {"A": 0, "B": 1} + encoder_decoder = LabeledSpanEncoderDecoder( + span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), + label2id=label2id, + mode=mode, + ) + expected_span = LabeledSpan(start=1, end=3, label="A") + remaining_encoding = [3, 4] + # encoding of the expected span + remaining encoding + encoding = encoder_decoder.encode(expected_span) + remaining_encoding + other_span = LabeledSpan(start=1, end=3, label="B") + with pytest.raises(DecodingSpanOverlapException) as excinfo: + encoder_decoder.parse(encoding, [other_span], 6) + assert ( + str(excinfo.value) + == f"the encoded span {expected_span} overlaps with another span with a different label: B" + ) + assert excinfo.value.remaining == remaining_encoding + + +def test_labeled_multi_span_encoder_decoder(): + """Test the LabeledMultiSpanEncoderDecoder class.""" + + label2id = {"A": 0, "B": 1} + encoder_decoder = LabeledMultiSpanEncoderDecoder( + span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), + label2id=label2id, + ) + # encode and decode a single span with two slices and label A + span = LabeledMultiSpan(slices=((1, 2), (4, 6)), label="A") + encoding = [3, 4, 6, 8, 0] + assert encoder_decoder.encode(span) == encoding + assert encoder_decoder.decode(encoding) == span + + # encoding empty slices are not allowed + with pytest.raises(EncodingEmptySlicesException) as excinfo: + encoder_decoder.encode(LabeledMultiSpan(slices=(), label="A")) + assert str(excinfo.value) == "LabeledMultiSpan must have at least one slice to encode it." + + # decoding an odd number of encoding entries is required for decoding + with pytest.raises(DecodingLengthException) as excinfo: + encoder_decoder.decode([3, 4, 6, 8]) + assert ( + str(excinfo.value) + == "an odd number of encoding entries is required for decoding a LabeledMultiSpan, but got 4" + ) + + +def test_labeled_multi_span_encoder_decoder_parse(): + """Test the LabeledMultiSpanEncoderDecoder class.""" + + label2id = {"A": 0, "B": 1} + encoder_decoder = LabeledMultiSpanEncoderDecoder( + span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), + label2id=label2id, + ) + expected_span = LabeledMultiSpan(slices=((1, 2), (4, 6)), label="A") + encoding = [3, 4, 6, 8, 0] + remaining_encoding = [0, 1] + # encoding of the expected span + remaining encoding + assert encoder_decoder.parse(encoding + remaining_encoding, [], 10) == ( + expected_span, + remaining_encoding, + ) + + +def test_labeled_multi_span_encoder_decoder_parse_incomplete(): + label2id = {"A": 0, "B": 1} + encoder_decoder = LabeledMultiSpanEncoderDecoder( + span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), + label2id=label2id, + ) + encoding = encoder_decoder.encode(LabeledMultiSpan(slices=((1, 2), (4, 6)), label="A")) + assert encoding == [3, 4, 6, 8, 0] + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse([], [], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect a start index, so the follow-up candidates are [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + # but we need to add the offset of 2 + assert excinfo.value.follow_up_candidates == [2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:1], [], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect an end index of a none-empty Span, i.e. [2, 3, 4, 5, 6, 7, 8, 9, 10], + # but we need to add the offset of 2 + assert excinfo.value.follow_up_candidates == [4, 5, 6, 7, 8, 9, 10, 11, 12] + + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:2], [], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as LabeledMultiSpan" + # we can follow up with 1) a label encoding, i.e. [0, 1], or 2) the start index of a new span encoding + # not covered by the previous slice which is Span(1, 2), i.e. [0] + [2, 3, 4, 5, 6, 7, 8, 9, 10], + # but we also allow to generate the exact same span again, so the index 1 is also allowed + # and we need to add the offset of 2 + assert excinfo.value.follow_up_candidates == [0, 1] + [2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:3], [], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect an end index of a none-empty Span, so the follow-up candidates are [7, 13) + assert excinfo.value.follow_up_candidates == list(range(7, 13)) + + +def test_labeled_multi_span_encoder_decoder_parse_incomplete_with_previous_annotations(): + label2id = {"A": 0, "B": 1} + encoder_decoder = LabeledMultiSpanEncoderDecoder( + span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), + label2id=label2id, + ) + # a span before the current start index should not affect the follow-up candidates + other_span_before = LabeledMultiSpan(slices=((0, 1), (2, 3)), label="A") + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse([], [other_span_before], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect a start index which can be in between the slices of the other_span_before or after the last slice, + # but we also allow to generate the exact same spans again, so the indices [0, 2] are also allowed + # i.e. [0] + [1] + [2] + [3, 4, 5, 6, 7, 8, 9], but we need to add the offset of 2 + assert excinfo.value.follow_up_candidates == [2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + + # a span after the current start index should limit the follow-up candidates + other_span_after = LabeledMultiSpan(slices=((5, 6), (7, 8)), label="A") + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse([3], [other_span_after], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect an end index of a none-empty Span and the span should not overlap with other_span_after, + # so allowed end indices are [2, 3, 4, 5], but we need to add the offset of 2 + assert excinfo.value.follow_up_candidates == [4, 5, 6, 7] + + entangled_span = LabeledMultiSpan(slices=((0, 2), (4, 5)), label="A") + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse([4, 5], [entangled_span], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as LabeledMultiSpan" + # we expect either a label encoding, i.e. [0, 1], or a new span encoding, but the start index is + # not allowed to be within any of the slices of the entangled_span, i.e. [0, 1] and [4], + # and not within the span itself, i.e. [2, 3], but we allow to generate the exact same spans again, + # so the indices 0, 4, and 2 are also allowed, so the allowed end indices are + # [0] + [2] + [3] + [4] + [5, 6, 7, 8, 9], but we need to add the offset of 2 + assert excinfo.value.follow_up_candidates == [0, 1] + [2] + [4] + [5] + [6] + [7, 8, 9, 10, 11] + + @pytest.mark.parametrize( "mode", ["head_tail_label", "tail_head_label", "label_head_tail", "label_tail_head"] ) @@ -439,3 +915,350 @@ def test_binary_relation_encoder_decoder_wrong_label_index(): encoder_decoder.decode([1, 2, 3, 4, 5, 6, 7]) assert str(excinfo.value) == "unknown label id: 7 (label2id: {'A': 0, 'B': 1, 'C': 2})" assert excinfo.value.identifier == "label" + + +@pytest.mark.parametrize( + "mode", ["head_tail_label", "tail_head_label", "label_head_tail", "label_tail_head"] +) +def test_binary_relation_encoder_decoder_parse(mode): + """Test the BinaryRelationEncoderDecoder class.""" + + label2id = {"A": 0, "B": 1, "C": 2} + labeled_span_encoder_decoder = LabeledSpanEncoderDecoder( + span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), + label2id=label2id, + mode="indices_label", + ) + encoder_decoder = BinaryRelationEncoderDecoder( + head_encoder_decoder=labeled_span_encoder_decoder, + tail_encoder_decoder=labeled_span_encoder_decoder, + label2id=label2id, + mode=mode, + ) + expected_relation = BinaryRelation( + head=LabeledSpan(start=1, end=2, label="A"), + tail=LabeledSpan(start=3, end=4, label="B"), + label="C", + ) + encoding = encoder_decoder.encode(expected_relation) + remaining_encoding = [2, 3] + # encoding of the expected relation + remaining encoding + assert encoder_decoder.parse(encoding + remaining_encoding, [], 10) == ( + expected_relation, + remaining_encoding, + ) + + +@pytest.mark.parametrize( + "mode", ["head_tail_label", "tail_head_label", "label_head_tail", "label_tail_head"] +) +def test_binary_relation_encoder_decoder_parse_loop_dummy_relation(mode): + """Test the BinaryRelationEncoderDecoder class.""" + + label2id = {"A": 0, "B": 1, "C": 2, "N": 3} + labeled_span_encoder_decoder = LabeledSpanEncoderDecoder( + span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), + label2id=label2id, + mode="indices_label", + ) + encoder_decoder = BinaryRelationEncoderDecoder( + head_encoder_decoder=labeled_span_encoder_decoder, + tail_encoder_decoder=labeled_span_encoder_decoder, + label2id=label2id, + mode=mode, + loop_dummy_relation_name="L", + none_label="N", + ) + expected_relation = BinaryRelation( + head=LabeledSpan(start=1, end=2, label="A"), + tail=LabeledSpan(start=1, end=2, label="A"), + label="L", + ) + encoding = encoder_decoder.encode(expected_relation) + remaining_encoding = [6, 7, 8] + # encoding of the expected relation + remaining encoding + assert encoder_decoder.parse(encoding + remaining_encoding, [], 10) == ( + expected_relation, + remaining_encoding, + ) + + +@pytest.mark.parametrize("mode", ["head_tail_label", "label_head_tail"]) +def test_binary_relation_encoder_decoder_parse_incomplete(mode): + span_label2id = {"A": 0, "B": 1} + relation_label2id = {"C": 2, "N": 3} + labeled_span_encoder_decoder = LabeledSpanEncoderDecoder( + span_encoder_decoder=SpanEncoderDecoderWithOffset( + offset=len(span_label2id) + len(relation_label2id) + ), + label2id=span_label2id, + mode="indices_label", + ) + encoder_decoder = BinaryRelationEncoderDecoder( + head_encoder_decoder=labeled_span_encoder_decoder, + tail_encoder_decoder=labeled_span_encoder_decoder, + label2id=relation_label2id, + mode=mode, + loop_dummy_relation_name="loop", + none_label="N", + ) + # Note: we use a tail that comes before the head on purpose to test the restricted follow-up candidates + relation = BinaryRelation( + head=LabeledSpan(start=3, end=4, label="A"), + tail=LabeledSpan(start=1, end=2, label="B"), + label="C", + ) + encoding = encoder_decoder.encode(relation) + # just to see the encoding + if mode.endswith("_label"): + assert encoding == [7, 8, 0, 5, 6, 1, 2] + elif mode.startswith("label_"): + assert encoding == [2, 7, 8, 0, 5, 6, 1] + else: + raise ValueError(f"unknown mode: {mode}") + + # check missing start index of first argument + if mode.endswith("_label"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse([], [], 10) + elif mode.startswith("label_"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:1], [], 10) + else: + raise ValueError(f"unknown mode: {mode}") + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect a start index which can be any valid start index, i.e. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + # but we need to add the offset of 4 + assert excinfo.value.follow_up_candidates == [4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + + # check missing end index of first argument + if mode.endswith("_label"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:1], [], 10) + elif mode.startswith("label_"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:2], [], 10) + else: + raise ValueError(f"unknown mode: {mode}") + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect an end index of a none-empty Span, i.e. [4, 5, 6, 7, 8, 9, 10], + # but we need to add the offset of 4 + assert excinfo.value.follow_up_candidates == [8, 9, 10, 11, 12, 13, 14] + + # check missing label of first argument + if mode.endswith("_label"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:2], [], 10) + elif mode.startswith("label_"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:3], [], 10) + else: + raise ValueError(f"unknown mode: {mode}") + assert str(excinfo.value) == "the encoding has not enough values to decode as LabeledSpan" + # we expect a label encoding + assert excinfo.value.follow_up_candidates == [0, 1] + + # check missing start index of second argument + if mode.endswith("_label"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:3], [], 10) + elif mode.startswith("label_"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:4], [], 10) + else: + raise ValueError(f"unknown mode: {mode}") + assert str(excinfo.value) == "the encoding has not enough values to decode as BinaryRelation" + # we expect + # 1) the none_label index, which is 3, or + # 2) a start index which can be any index that is not covered by the first + # argument span which is LabeledSpan(start=3, end=4, label="A"), + # but we allow to generate the exact same span again, i.e. [0, 1, 2] + [3] + [4, 5, 6, 7, 8, 9], + # but we need to add the offset of 4 + assert excinfo.value.follow_up_candidates == [3] + [4, 5, 6] + [7] + [8, 9, 10, 11, 12, 13] + + # check missing end index of second argument + if mode.endswith("_label"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:4], [], 10) + elif mode.startswith("label_"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:5], [], 10) + else: + raise ValueError(f"unknown mode: {mode}") + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect an end index of a none-empty Span and the span should not overlap with the first argument span + # which is LabeledSpan(start=3, end=4, label="A"), so allowed end indices are [2, 3] + # but we need to add the offset of 4 + assert excinfo.value.follow_up_candidates == [6, 7] + + # check missing label of second argument + if mode.endswith("_label"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:5], [], 10) + elif mode.startswith("label_"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:6], [], 10) + else: + raise ValueError(f"unknown mode: {mode}") + assert str(excinfo.value) == "the encoding has not enough values to decode as LabeledSpan" + # we expect a label encoding + assert excinfo.value.follow_up_candidates == [0, 1] + + # check missing label + if mode.endswith("_label"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:6], [], 10) + # we expect a relation label encoding + assert excinfo.value.follow_up_candidates == [2] + elif mode.startswith("label_"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse([], [], 10) + # we expect a relation label encoding, but can not exclude the none label encoding so far + assert excinfo.value.follow_up_candidates == [2, 3] + else: + raise ValueError(f"unknown mode: {mode}") + assert str(excinfo.value) == "the encoding has not enough values to decode as BinaryRelation" + + # check loop dummy relation + encoding = encoder_decoder.encode( + BinaryRelation( + head=LabeledSpan(start=3, end=4, label="A"), + tail=LabeledSpan(start=3, end=4, label="A"), + label="loop", + ) + ) + if mode.endswith("_label"): + assert encoding == [7, 8, 0, 3, 3, 3, 3] + elif mode.startswith("label_"): + assert encoding == [3, 7, 8, 0, 3, 3, 3] + else: + raise ValueError(f"unknown mode: {mode}") + + # check missing second none id in second argument + if mode.endswith("_label"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:4], [], 10) + elif mode.startswith("label_"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:5], [], 10) + else: + raise ValueError(f"unknown mode: {mode}") + assert str(excinfo.value) == "the encoding has not enough values to decode as BinaryRelation" + # we expect the none_label index, which is 3 + assert excinfo.value.follow_up_candidates == [3] + + # check missing third none id in second argument + if mode.endswith("_label"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:5], [], 10) + elif mode.startswith("label_"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:6], [], 10) + else: + raise ValueError(f"unknown mode: {mode}") + assert str(excinfo.value) == "the encoding has not enough values to decode as BinaryRelation" + # we expect the none_label index, which is 3 + assert excinfo.value.follow_up_candidates == [3] + + # check missing none id as loop dummy relation id + if mode.endswith("_label"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:6], [], 10) + # we expect the none_label index, which is 3 + assert excinfo.value.follow_up_candidates == [3] + elif mode.startswith("label_"): + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse([], [], 10) + # we expect the none_label index, which is 3, but can not exclude the relation label encoding so far + assert excinfo.value.follow_up_candidates == [2, 3] + else: + raise ValueError(f"unknown mode: {mode}") + assert str(excinfo.value) == "the encoding has not enough values to decode as BinaryRelation" + + +def test_binary_relation_encoder_decoder_parse_incomplete_with_previous_annotations(): + label2id = {"A": 0, "B": 1, "C": 2} + labeled_span_encoder_decoder = LabeledSpanEncoderDecoder( + span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), + label2id=label2id, + mode="indices_label", + ) + encoder_decoder = BinaryRelationEncoderDecoder( + head_encoder_decoder=labeled_span_encoder_decoder, + tail_encoder_decoder=labeled_span_encoder_decoder, + label2id=label2id, + mode="head_tail_label", + ) + relation = BinaryRelation( + head=LabeledSpan(start=3, end=4, label="A"), + tail=LabeledSpan(start=1, end=2, label="B"), + label="C", + ) + encoding = encoder_decoder.encode(relation) + assert encoding == [6, 7, 0, 4, 5, 1, 2] + surrounding_relation = BinaryRelation( + head=LabeledSpan(start=0, end=1, label="A"), + tail=LabeledSpan(start=6, end=7, label="B"), + label="C", + ) + + # test the first span start index + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse([], [surrounding_relation], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect a start index which can be any index that is not covered by the surrounding_relation argument spans, + # which are LabeledSpan(start=0, end=1, label="A") and LabeledSpan(start=6, end=7, label="B"), + # but we allow to generate the exact same span again, i.e. [0] + [1, 2, 3, 4, 5] + [6] + [7, 8, 9] + # but we need to add the offset of 3 + assert excinfo.value.follow_up_candidates == [3] + [4, 5, 6, 7, 8] + [9] + [10, 11, 12] + + # test the first span end index + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:1], [surrounding_relation], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect an end index of a none-empty Span and the span should not overlap with the + # surrounding_relation argument spans, so allowed end indices are [4, 5, 6], + # but we need to add the offset of 3 + assert excinfo.value.follow_up_candidates == [7, 8, 9] + + # test the first span label + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:2], [surrounding_relation], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as LabeledSpan" + # we expect a label encoding + assert excinfo.value.follow_up_candidates == [0, 1, 2] + + # test the second span start index + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:3], [surrounding_relation], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect a start index which can be any index that is not covered by the surrounding_relation argument spans, + # which are LabeledSpan(start=0, end=1, label="A") and LabeledSpan(start=6, end=7, label="B"), + # and not by the first argument span which is LabeledSpan(start=3, end=4, label="A"), + # but we allow to generate the exact same spans again, + # i.e. [0] + [1, 2] + [3] + [4, 5] + [6] + [7, 8, 9], but we need to add the offset of 3 + assert excinfo.value.follow_up_candidates == [3] + [4, 5] + [6] + [7, 8] + [9] + [10, 11, 12] + + # test the second span end index + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:4], [surrounding_relation], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as Span" + # we expect an end index of a none-empty Span and the span should not overlap with the surrounding_relation + # argument spans which are LabeledSpan(start=0, end=1, label="A") and LabeledSpan(start=6, end=7, label="B"), + # or the first argument span which is LabeledSpan(start=3, end=4, label="A"), so allowed end indices are [2, 3], + # but we need to add the offset of 3 + assert excinfo.value.follow_up_candidates == [5, 6] + + # test the second span label + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:5], [surrounding_relation], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as LabeledSpan" + # we expect a label encoding + assert excinfo.value.follow_up_candidates == [0, 1, 2] + + # test the relation label + with pytest.raises(IncompleteEncodingException) as excinfo: + encoder_decoder.parse(encoding[:6], [surrounding_relation], 10) + assert str(excinfo.value) == "the encoding has not enough values to decode as BinaryRelation" + # we expect a label encoding + assert excinfo.value.follow_up_candidates == [0, 1, 2] diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index 65a217240..da32ec49a 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -1,5 +1,6 @@ import logging import pickle +import random from dataclasses import asdict, dataclass from typing import Dict, List, Set @@ -211,6 +212,7 @@ def test_prepared_config(taskmodule, config): "entities": "labeled_spans", "relations": "binary_relations", }, + "constrain_with_previous_records": True, "constrained_generation": False, "label_tokens": None, "label_representations": None, @@ -238,6 +240,7 @@ def test_prepared_config(taskmodule, config): "entities": "labeled_spans", "relations": "binary_relations", }, + "constrain_with_previous_records": True, "constrained_generation": False, "label_tokens": None, "label_representations": None, @@ -285,14 +288,14 @@ def test_target_encoding(target_encoding, taskmodule): [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1], + [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1], + [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], @@ -306,7 +309,7 @@ def test_target_encoding(target_encoding, taskmodule): [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1], + [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], @@ -351,14 +354,14 @@ def test_build_constraints(taskmodule, task_encoding, config): [[0, 1], [0], [0, 0, 0], [0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], # 14 [[0, 0], [0], [0, 0, 0], [0], [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]], # 14 [[0, 0], [0], [1, 1, 1], [0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # 5 - [[0, 0], [1], [0, 0, 0], [0], [1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1]], # 11 + [[0, 0], [1], [0, 0, 0], [0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], # 11 [[0, 0], [0], [0, 0, 0], [0], [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0]], # 12 [[0, 0], [0], [1, 1, 1], [0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # 3 [[0, 0], [0], [0, 0, 0], [1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # 6 - [[0, 1], [0], [0, 0, 0], [0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], # 17 + [[0, 1], [0], [0, 0, 0], [0], [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1]], # 17 [[0, 0], [0], [0, 0, 0], [0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]], # 17 [[0, 0], [0], [1, 1, 1], [0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # 4 - [[0, 0], [1], [0, 0, 0], [0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1]], # 2 + [[0, 0], [1], [0, 0, 0], [0], [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1]], # 2 [[0, 0], [1], [0, 0, 0], [0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # 2 [[0, 0], [1], [0, 0, 0], [0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # 2 [[0, 0], [1], [0, 0, 0], [0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # 2 @@ -377,7 +380,7 @@ def test_build_constraints(taskmodule, task_encoding, config): [[0, 1], [0], [0, 0, 0], [0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], # 14 [[0, 0], [0], [0, 0, 0], [0], [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]], # 14 [[0, 0], [0], [1, 1, 1], [0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # 5 - [[0, 0], [1], [0, 0, 0], [0], [1, 1, 1, 1, 1, 1, 1, 0, 1, 1]], # 11 + [[0, 0], [1], [0, 0, 0], [0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], # 11 [[0, 0], [0], [0, 0, 0], [0], [0, 0, 0, 0, 1, 1, 1, 0, 0, 0]], # 12 [[0, 0], [0], [1, 1, 1], [0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # 3 [[0, 0], [0], [0, 0, 0], [1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # 6 @@ -387,12 +390,23 @@ def test_build_constraints(taskmodule, task_encoding, config): raise Exception(f"unknown config: {config}") +def build_constraint_mask( + taskmodule, + previous_ids: List[int], + input_len: int, +) -> torch.LongTensor: + follow_up_candidates = taskmodule.get_follow_up_candidates( + previous_ids=previous_ids, input_len=input_len + ) + return taskmodule.follow_up_candidates_to_mask(follow_up_candidates, input_len) + + def test_build_constraint(taskmodule): target_ids = [14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1] input_len = 13 # empty previous_ids - constraint = taskmodule._build_constraint(previous_ids=[], input_len=input_len) + constraint = build_constraint_mask(taskmodule, previous_ids=[], input_len=input_len) # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) # allow eos and all offsets @@ -405,7 +419,7 @@ def test_build_constraint(taskmodule): ] # just first span start - constraint = taskmodule._build_constraint(previous_ids=[14], input_len=input_len) + constraint = build_constraint_mask(taskmodule, previous_ids=[14], input_len=input_len) # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) # allow all offsets after first span start @@ -418,7 +432,7 @@ def test_build_constraint(taskmodule): ] # first span start and end - constraint = taskmodule._build_constraint(previous_ids=[14, 14], input_len=input_len) + constraint = build_constraint_mask(taskmodule, previous_ids=[14, 14], input_len=input_len) # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) # allow all span ids @@ -431,7 +445,7 @@ def test_build_constraint(taskmodule): ] # first span start, end, and label - constraint = taskmodule._build_constraint(previous_ids=[14, 14, 5], input_len=input_len) + constraint = build_constraint_mask(taskmodule, previous_ids=[14, 14, 5], input_len=input_len) # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) # allow none and all offsets except offsets covered by first span @@ -440,11 +454,13 @@ def test_build_constraint(taskmodule): [1], [0, 0, 0], [0], - [1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], ] # first span, and second span start - constraint = taskmodule._build_constraint(previous_ids=[14, 14, 5, 11], input_len=input_len) + constraint = build_constraint_mask( + taskmodule, previous_ids=[14, 14, 5, 11], input_len=input_len + ) # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) # allow all offsets after second span start, but not after first span start @@ -457,8 +473,8 @@ def test_build_constraint(taskmodule): ] # first span, and second span start and end - constraint = taskmodule._build_constraint( - previous_ids=[14, 14, 5, 11, 12], input_len=input_len + constraint = build_constraint_mask( + taskmodule, previous_ids=[14, 14, 5, 11, 12], input_len=input_len ) # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) @@ -472,8 +488,8 @@ def test_build_constraint(taskmodule): ] # first span, and second span - constraint = taskmodule._build_constraint( - previous_ids=[14, 14, 5, 11, 12, 3], input_len=input_len + constraint = build_constraint_mask( + taskmodule, previous_ids=[14, 14, 5, 11, 12, 3], input_len=input_len ) # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) @@ -489,8 +505,8 @@ def test_build_constraint(taskmodule): # fist span, and (1 to 3)-times none for i in range(1, 3): none_ids = [2] * i - constraint = taskmodule._build_constraint( - previous_ids=[14, 14, 5] + none_ids, input_len=input_len + constraint = build_constraint_mask( + taskmodule, previous_ids=[14, 14, 5] + none_ids, input_len=input_len ) # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) @@ -504,8 +520,8 @@ def test_build_constraint(taskmodule): ] # contains eos - constraint = taskmodule._build_constraint( - previous_ids=[14, 14, 5, 11, 12, 3, 6, 1], input_len=input_len + constraint = build_constraint_mask( + taskmodule, previous_ids=[14, 14, 5, 11, 12, 3, 6, 1], input_len=input_len ) # [bos, eos/pad], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) @@ -518,6 +534,8 @@ def test_build_constraint(taskmodule): [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], ] + # TODO: test with decoded_relations + def test_maybe_log_example(taskmodule, task_encoding, caplog, config): original_log_first_n_examples = taskmodule.log_first_n_examples @@ -596,14 +614,14 @@ def test_collate(batch, taskmodule): [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1], + [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1], + [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], @@ -627,7 +645,7 @@ def test_collate(batch, taskmodule): [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1], + [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], @@ -637,7 +655,7 @@ def test_collate(batch, taskmodule): [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, -1, -1, -1, -1, -1], [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, -1, -1, -1, -1, -1], + [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], @@ -736,13 +754,15 @@ def test_annotations_from_output(task_encodings, task_outputs, taskmodule): ) -def get_default_taskmodule(**kwargs): - taskmodule = PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path="facebook/bart-base", - labels_per_layer={ +def get_default_taskmodule(labels_per_layer=None, **kwargs): + if labels_per_layer is None: + labels_per_layer = { "labeled_spans": ["content", "person", "topic"], "binary_relations": ["is_about"], - }, + } + taskmodule = PointerNetworkTaskModuleForEnd2EndRE( + tokenizer_name_or_path="facebook/bart-base", + labels_per_layer=labels_per_layer, **kwargs, ) taskmodule.post_prepare() @@ -794,7 +814,13 @@ def test_configure_model_metric(): values = metric.compute() assert values == { "exact_encoding_matches": 0.5, - "decoding_errors": {"correct": 0.5, "len": 0.25, "order": 0.25, "all": 0.5}, + "decoding_errors": { + "correct": 0.25, + "negative_index": 0.5, + "order": 0.125, + "label": 0.125, + "all": 0.75, + }, "labeled_spans": { "person": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816}, "topic": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816}, @@ -862,7 +888,8 @@ def test_prefix_allowed_tokens_fn_with_maximum(): batch_id=0, input_ids=add_previous_input_ids[:4], maximum=20 ) # allow none [2] and all offsets except offsets covered by first span [14] - assert allowed_ids == [2, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19] + # TODO: no, 14 is also allowed! because of generating the exact same spans is allowed. or better not? + assert allowed_ids == [2, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] # first span, and second span start allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( @@ -882,15 +909,15 @@ def test_prefix_allowed_tokens_fn_with_maximum(): allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( batch_id=0, input_ids=add_previous_input_ids[:7], maximum=20 ) - # allow all relation ids - assert allowed_ids == [6] + # allow all relation ids. we also allow the eos [1] because two entries are sampled each step + assert allowed_ids == [1, 6] # entry begins (second entry) allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( batch_id=0, input_ids=add_previous_input_ids[:8], maximum=20 ) - # allow eos [1] and all offsets [7..19] - assert allowed_ids == [1, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + # allow eos [1] and all offsets [7..19] except the ones covered by the first entry, i.e. [12] + assert allowed_ids == [1, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19] # first span start allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( @@ -911,35 +938,39 @@ def test_prefix_allowed_tokens_fn_with_maximum(): batch_id=0, input_ids=add_previous_input_ids[:11], maximum=20 ) # allow none [2] and all offsets except offsets covered by first span [17] - assert allowed_ids == [2, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19] + # TODO: no, 17 is also allowed! because of generating the exact same spans is allowed. or better not? + assert allowed_ids == [2, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19] # first span, and none allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( batch_id=0, input_ids=add_previous_input_ids[:12], maximum=20 ) # allow only none [2] because when the entry contains already a none id, it cannot be followed by anything else - assert allowed_ids == [2] + # we also allow the eos [1] because two entries are sampled each step + assert allowed_ids == [1, 2] # first span, and none, and none allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( batch_id=0, input_ids=add_previous_input_ids[:13], maximum=20 ) # allow only none [2] because when the entry contains already a none id, it cannot be followed by anything else - assert allowed_ids == [2] + # we also allow the eos [1] because two entries are sampled each step + assert allowed_ids == [1, 2] # first span, and none, and none, and none allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( batch_id=0, input_ids=add_previous_input_ids[:14], maximum=20 ) # allow only none [2] because when the entry contains already a none id, it cannot be followed by anything else - assert allowed_ids == [2] + # we also allow the eos [1] because two entries are sampled each step + assert allowed_ids == [1, 2] # first span, and none, and none, and none, and none (second entry is complete) allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( batch_id=0, input_ids=add_previous_input_ids[:15], maximum=20 ) - # allow eos [1] and all offsets [7..19] - assert allowed_ids == [1, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] + # allow eos [1] and all offsets [7..19], except the ones covered by the first entry, i.e. [12] + assert allowed_ids == [1, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19] # got an eos, so the sequence is complete allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( @@ -947,3 +978,21 @@ def test_prefix_allowed_tokens_fn_with_maximum(): ) # allow only pad [1] (same as eos) because the sequence is complete assert allowed_ids == [1] + + +def test_decode_annotations_fuzzing(): + taskmodule = get_default_taskmodule( + labels_per_layer={ + "binary_relations": ["contradicts", "parts_of_same", "semantically_same", "supports"], + "labeled_spans": ["background_claim", "data", "own_claim"], + } + ) + random.seed(42) + input_length = 100 + output_length = 30 + for _ in range(1000): + encoding = random.sample(range(0, input_length + taskmodule.pointer_offset), output_length) + encoding_without_bos = [idx for idx in encoding if idx != taskmodule.bos_id] + taskmodule.annotation_encoder_decoder.parse_with_error_handling( + encoding=encoding_without_bos, input_length=input_length, stop_ids=[taskmodule.eos_id] + )