Skip to content

Commit

Permalink
Reversed logic of triplet storage, making it more simple and thus les…
Browse files Browse the repository at this point in the history
…s error-prone
  • Loading branch information
KasperFyhn committed Oct 31, 2023
1 parent 915645c commit 4c34006
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,7 @@ def set_annotations(self, doc: Iterable[Doc], predictions: Dict) -> None:
# Output empty lists if empty doc or no extractions above threshold
if not predictions["extraction"]:
setattr(doc._, "relation_confidence", []) # type: ignore
setattr(doc._, "relation_head", []) # type: ignore
setattr(doc._, "relation_relation", []) # type: ignore
setattr(doc._, "relation_tail", []) # type: ignore
setattr(doc._, "relation_triplets", []) # type: ignore
return

# concatenate wordpieces and concatenate extraction span. Handle new extraction
Expand All @@ -123,13 +121,7 @@ def set_annotations(self, doc: Iterable[Doc], predictions: Dict) -> None:
# Set doc level attributes
merged_confidence = [j for i in predictions["confidence"] for j in i]
setattr(doc._, "relation_confidence", merged_confidence) # type: ignore
setattr(doc._, "relation_head", aligned_extractions["head"]) # type: ignore
setattr(
doc._, # type: ignore
"relation_relation",
aligned_extractions["relation"],
)
setattr(doc._, "relation_tail", aligned_extractions["tail"]) # type: ignore
setattr(doc._, "relation_triplets", aligned_extractions) # type: ignore

def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ def wp_span_to_token(
relation_span: List[List[int]],
wp_tokenid_mapping: Dict,
doc: Doc,
) -> Dict[str, List[Span]]:
) -> List[Tuple[Span, Span, Span]]:
"""Converts the wp span for each relation to spans.
Assumes that relations are contiguous
"""
relations = {"triplet": [], "head": [], "relation": [], "tail": []} # type: ignore
relations = [] # type: ignore
for triplet in relation_span:
# convert list of wordpieces in the extraction to a tuple of the span (start,
# end)
Expand All @@ -73,9 +73,7 @@ def wp_span_to_token(
relation = token_span_to_spacy_span(relation, doc)
tail = token_span_to_spacy_span(tail, doc)

relations["head"].append(head)
relations["relation"].append(relation)
relations["tail"].append(tail)
relations.append((head, relation, tail))
return relations


Expand Down Expand Up @@ -118,43 +116,39 @@ def idx_to_span(idx: Tuple[int, int], doc: Doc) -> Span:
return doc[start:end]


def install_extension(ext):
Doc.set_extension(ext + "_idxs", default=None)
Doc.set_extension(
ext,
setter=lambda doc, spans: setattr(
doc._,
ext + "_idxs",
[span_to_idx(span) for span in spans],
),
getter=lambda doc: [
idx_to_span(idx, doc) for idx in getattr(doc._, ext + "_idxs")
],
)


def install_extensions() -> None:
"""Sets extensions on the SpaCy Doc class.
Relation heads, relations and tails are stored internally as (start,
end) index tuples, but they are created with getters and setters
that map the index tuples to and from SpaCy Span objects. Full
triplets are retrieved by looping over the (assumed) aligned heads,
relations and tails. Confidence numbers are stored as is.
Relation triplets are stored internally as index tuples, but they
are created with getters and setters that map the index tuples to
and from SpaCy Span objects. Heads, relations and tails are
retrieved from the triplets. Confidence numbers are stored as is.
"""
for ext in ["relation_head", "relation_relation", "relation_tail"]:
install_extension(ext)
Doc.set_extension("relation_triplet_idxs", default=None)
Doc.set_extension(
"relation_triplets",
setter=lambda doc, triplets: setattr(
doc._,
"relation_triplet_idxs",
[tuple(span_to_idx(span) for span in triplet) for triplet in triplets],
),
getter=lambda doc: [
(head, rel, tail)
for head, rel, tail in zip(
doc._.relation_head,
doc._.relation_relation,
doc._.relation_tail,
)
tuple(idx_to_span(idx, doc) for idx in triplet_idx)
for triplet_idx in doc._.relation_triplet_idxs
],
)
Doc.set_extension(
"relation_head",
getter=lambda doc: [t[0] for t in doc._.relation_triplets],
)
Doc.set_extension(
"relation_relation",
getter=lambda doc: [t[1] for t in doc._.relation_triplets],
)
Doc.set_extension(
"relation_tail",
getter=lambda doc: [t[2] for t in doc._.relation_triplets],
)
Doc.set_extension("relation_confidence", default=None)


Expand Down
46 changes: 28 additions & 18 deletions tests/test_relationextraction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,36 @@


def test_relationextraction_doc_extension():
"""Verifies the behavior of the relation triplet extensions and the lambdas
that back them."""
"""Verifies the behavior of the relation triplet extension and the lambdas
that back it."""
multi2oie_utils.install_extensions()

words = "this is a test".split()
words = "this is a test . the test seems cool".split()
vocab = Vocab(strings=words)
test_doc = Doc(vocab, words=words)

head = test_doc[0:1]
relation = test_doc[1:2]
tail = test_doc[2:4]

test_doc._.relation_head = [head]
test_doc._.relation_relation = [relation]
test_doc._.relation_tail = [tail]

# verify the extension that contains the actual indices
assert test_doc._.relation_head_idxs == [(0, 1)]
assert test_doc._.relation_relation_idxs == [(1, 2)]
assert test_doc._.relation_tail_idxs == [(2, 4)]

# verify that relation_triplets are correctly built from the other extensions
assert test_doc._.relation_triplets == [(head, relation, tail)]
this = test_doc[0:1]
is_ = test_doc[1:2]
a_test = test_doc[2:4]
the_test = test_doc[5:7]
seems = test_doc[7:8]
cool = test_doc[8:9]

test_doc._.relation_triplets = [(this, is_, a_test), (the_test, seems, cool)]

# check setter/getter mirroring
assert test_doc._.relation_triplets == [
(this, is_, a_test),
(the_test, seems, cool),
]

# check index extension
assert test_doc._.relation_triplet_idxs == [
((0, 1), (1, 2), (2, 4)),
((5, 7), (7, 8), (8, 9)),
]

# check heads, relations and tails
assert test_doc._.relation_head == [this, the_test]
assert test_doc._.relation_relation == [is_, seems]
assert test_doc._.relation_tail == [a_test, cool]

0 comments on commit 4c34006

Please sign in to comment.