Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pointer network for multi spans #63

Open
wants to merge 66 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
31831f2
add use_multi_spans property
ArneBinder Feb 5, 2024
9304ff4
add todos
ArneBinder Feb 5, 2024
6a8f92a
convert all AnnotationEncoderDecoder to GenerativeAnnotationEncoderDe…
ArneBinder Feb 7, 2024
9c7a5d5
add LabeledMultiSpanEncoderDecoder (still without tests)
ArneBinder Feb 7, 2024
b0d3766
minor
ArneBinder Feb 7, 2024
c379cb1
handle none_id in BinaryRelationEncoderDecoder.parse()
ArneBinder Feb 7, 2024
81600fb
fix LabeledMultiSpanEncoderDecoder.parse()
ArneBinder Feb 8, 2024
7f1aee5
rename exception
ArneBinder Feb 8, 2024
8d11afd
raise error for nested / overlapping parsed spans
ArneBinder Feb 8, 2024
f14e9c2
fix IncompleteEncodingException for SpanEncoderDecoderWithOffset.parse
ArneBinder Feb 11, 2024
3419813
better constrain the type for span_encoder_decoder in LabeledSpanEnco…
ArneBinder Feb 11, 2024
485f680
disentangle raising DecodingSpanOverlapException and DecodingSpanNest…
ArneBinder Feb 11, 2024
8acd880
add EncodingException
ArneBinder Feb 12, 2024
fd7bf72
add DecodingEmptySpanException, EncodingEmptySpanException, spans_hav…
ArneBinder Feb 12, 2024
c0ca07c
add documentation
ArneBinder Feb 12, 2024
3f83d5b
assert in SpanEncoderDecoder.encode() that span is not empty
ArneBinder Feb 12, 2024
1f215e4
fix SpanEncoderDecoder.parse() and add tests
ArneBinder Feb 12, 2024
1c9a1cc
add tests for SpanEncoderDecoderWithOffset.parse()
ArneBinder Feb 12, 2024
0a2803c
fix SpanEncoderDecoderWithOffset.parse()
ArneBinder Feb 13, 2024
195247e
add tests for LabeledSpanEncoderDecoder.parse()
ArneBinder Feb 13, 2024
4c9518c
improve thrown Exceptions and add tests for LabeledMultiSpanEncoderDe…
ArneBinder Feb 13, 2024
259e5ed
fix SpanEncoderDecoder.parse() for empty encoding
ArneBinder Feb 13, 2024
33b8fad
fix LabeledMultiSpanEncoderDecoder.parse() and add tests
ArneBinder Feb 13, 2024
5270ace
add test_labeled_multi_span_encoder_decoder_parse_incomplete_with_pre…
ArneBinder Feb 13, 2024
d1773ec
add test_binary_relation_encoder_decoder_parse_incomplete() and test_…
ArneBinder Feb 13, 2024
8095d63
tests for parsing loop dummy relations
ArneBinder Feb 13, 2024
b909b78
allow to generate exact same spans again
ArneBinder Feb 13, 2024
6d264f3
adjust PointerNetworkTaskModuleForEnd2EndRE for LabeledMultiSpan (par…
ArneBinder Feb 13, 2024
e902a32
decoding exceptions contain remaining encoding when raised in parse()
ArneBinder Feb 13, 2024
e55436d
use relation_encoder_decoder.parse() in decode_relations()
ArneBinder Feb 13, 2024
379d715
rename error identifier "index" to "negative_index"
ArneBinder Feb 13, 2024
7719079
outsource _parse_label and restrict labels if none was found
ArneBinder Feb 13, 2024
bfd7148
use relation_encoder_decoder.parse() in build_constraint(s)
ArneBinder Feb 13, 2024
432f319
implement GenerativeAnnotationEncoderDecoderWithParseWithErrors with …
ArneBinder Feb 13, 2024
6054d38
minor fix
ArneBinder Feb 13, 2024
2bff31a
cleanup
ArneBinder Feb 13, 2024
512b015
ensure that parse_with_error_handling() does not get stuck in an infi…
ArneBinder Feb 14, 2024
17e8265
fix forwarding DecodingExceptions in SpanEncoderDecoderWithOffset.par…
ArneBinder Feb 14, 2024
573a264
cleanup sanity check in encode_annotations()
ArneBinder Feb 14, 2024
50071e9
cleanup follow-up candidate / constraints creation
ArneBinder Feb 14, 2024
fd32a81
fix PrefixConstrainedLogitsProcessorWithMaximum
ArneBinder Feb 14, 2024
5331d06
raise an exception if no follow-up candidates are found
ArneBinder Feb 14, 2024
28eb200
enforce continuations when previously decoded spans are re-generated
ArneBinder Feb 14, 2024
cc54e47
harden _prefix_allowed_tokens_fn_with_maximum(): if the decoding fail…
ArneBinder Feb 14, 2024
12e4b44
harden _prefix_allowed_tokens_fn_with_maximum(): If there is only one…
ArneBinder Feb 14, 2024
116e8ac
_prefix_allowed_tokens_fn_with_maximum(): log a warning if get_follow…
ArneBinder Feb 14, 2024
f3fa6ea
LabeledSpanEncoderDecoder.parse(): check for same span with different…
ArneBinder Feb 14, 2024
6eee47f
LabeledMultiSpanEncoderDecoder.parse(): check for overlap with previo…
ArneBinder Feb 15, 2024
1ff433d
harden decode_annotations(): catch any exception
ArneBinder Feb 15, 2024
db009c7
increase test coverage
ArneBinder Feb 15, 2024
759936c
better error message
ArneBinder Feb 15, 2024
5bd3b2d
add test_decode_annotations_issue()
ArneBinder Feb 16, 2024
b48ac8b
fix spans_have_overlap() and SpanEncoderDecoder.parse()
ArneBinder Feb 16, 2024
31a2fb7
add test_decode_annotations_fuzzing()
ArneBinder Feb 16, 2024
c1486e6
fix variable name
ArneBinder Feb 16, 2024
5788099
speed up build_constraints()
ArneBinder Feb 16, 2024
22d0d35
add cache_decoded and to speed up get_follow_up_candidates() and use …
ArneBinder Feb 16, 2024
6534f94
disable cache because it does not work as expected
ArneBinder Feb 16, 2024
1948b9d
add comment
ArneBinder Feb 16, 2024
c6b13e2
fixes regarding #64
ArneBinder Feb 19, 2024
9b23bf7
add disrespect_decoded_annotations parameter to parse_with_error_hand…
ArneBinder Feb 19, 2024
9768e65
add constrain_with_previous_records parameter to PointerNetworkTaskMo…
ArneBinder Feb 19, 2024
e8637e1
add test case
ArneBinder Feb 27, 2024
34273f2
use are_nested and have_overlap from pie_modules.utils.span
ArneBinder Feb 27, 2024
1aff783
disentangle and generalize naming
ArneBinder Mar 4, 2024
35d3412
pre-commit
ArneBinder Mar 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion src/pie_modules/taskmodules/common/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import abc
import logging
from collections import defaultdict
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar

from pytorch_ie import Annotation

logger = logging.getLogger(__name__)

# Annotation Encoding type: encoding for a single annotation
AE = TypeVar("AE")
# Annotation type
Expand All @@ -17,9 +21,20 @@ class DecodingException(Exception, Generic[AE], abc.ABC):

identifier: str

def __init__(self, message: str, encoding: AE):
def __init__(self, message: str, encoding: AE, remaining: Optional[AE] = None):
self.message = message
self.encoding = encoding
self.remaining = remaining


class EncodingException(Exception, Generic[A], abc.ABC):
"""Exception raised when encoding fails."""

identifier: str

def __init__(self, message: str, annotation: A):
self.message = message
self.annotation = annotation


class AnnotationEncoderDecoder(abc.ABC, Generic[A, AE]):
Expand All @@ -32,3 +47,64 @@ def encode(self, annotation: A, metadata: Optional[Dict[str, Any]] = None) -> AE
@abc.abstractmethod
def decode(self, encoding: AE, metadata: Optional[Dict[str, Any]] = None) -> A:
pass


class GenerativeAnnotationEncoderDecoder(AnnotationEncoderDecoder[A, AE], abc.ABC):
"""Base class for generative annotation encoders and decoders."""

@abc.abstractmethod
def parse(self, encoding: AE, decoded_annotations: List[A], text_length: int) -> Tuple[A, AE]:
"""Parse the encoding and return the decoded annotation and the remaining encoding."""
pass


class GenerativeAnnotationEncoderDecoderWithParseWithErrors(
Generic[A], GenerativeAnnotationEncoderDecoder[A, List[int]], abc.ABC
):
KEY_INVALID_CORRECT = "correct"

def parse_with_error_handling(
self,
encoding: List[int],
input_length: int,
stop_ids: List[int],
errors: Optional[Dict[str, int]] = None,
decoded_annotations: Optional[List[A]] = None,
disrespect_decoded_annotations: bool = False,
) -> Tuple[List[A], Dict[str, int], List[int]]:
errors = errors or defaultdict(int)
decoded_annotations = decoded_annotations or []
valid_encoding: A
successfully_decoded: List[int] = []
remaining = encoding
prev_len = len(remaining)
while len(remaining) > 0:
if remaining[0] in stop_ids:
# we discard everything after any stop id
break
try:
valid_encoding, remaining = self.parse(
encoding=remaining,
decoded_annotations=decoded_annotations
if not disrespect_decoded_annotations
else [],
text_length=input_length,
)
decoded_annotations.append(valid_encoding)
errors[self.KEY_INVALID_CORRECT] += 1
successfully_decoded = encoding[: len(encoding) - len(remaining)]
except DecodingException as e:
if e.remaining is None:
raise ValueError(f"decoding exception did not return remaining encoding: {e}")
errors[e.identifier] += 1
remaining = e.remaining

# if we did not consume any ids, we discard the first remaining one
if len(remaining) == prev_len:
logger.warning(
f"parse did not consume any ids, discarding first id from {remaining}"
)
remaining = remaining[1:]
prev_len = len(remaining)

return decoded_annotations, dict(errors), encoding[len(successfully_decoded) :]
Loading
Loading