Major step to streamline data modelling for both relation extraction methods
KasperFyhn committed Nov 1, 2023
1 parent 4c34006 commit d66d58b
Showing 16 changed files with 193 additions and 113 deletions.
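The heart of the change: the triplet data model and its Doc extensions now live in a single shared module that both relation extraction methods (GPT prompting and Multi2OIE) import from. The new canonical import, as used throughout the diff below:

from conspiracies.docprocessing.relationextraction.data_classes import (
    DocTriplets,
    SpanTriplet,
    install_extensions,
)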
2 changes: 1 addition & 1 deletion paper/extract_examples.py
@@ -8,7 +8,7 @@
import re
from typing import List, Tuple
from spacy.tokens import Doc
from conspiracies.docprocessing.relationextraction.gptprompting.data_classes import (
from conspiracies.docprocessing.relationextraction.data_classes import (
DocTriplets,
)

14 changes: 11 additions & 3 deletions paper/src/ents_heads_extraction.py
@@ -1,24 +1,32 @@
"""Pipeline for headwords/entities extractions and frequency count."""
import multiprocessing
import os

import spacy
import torch
from tqdm import tqdm

from conspiracies.docprocessing.headwordextraction import contains_ents


def main():
nlp = spacy.load("en_core_web_lg")

test_sents = ["Mette Frederiksen is the Danish politician."]
test_sents = ["Mette Frederiksen is the Danish politician."] * 1000

config = {"confidence_threshold": 2.7, "model_args": {"batch_size": 10}}
nlp.add_pipe("relation_extractor", config=config)
nlp.add_pipe("heads_extraction")

pipe = nlp.pipe(test_sents)
# multiprocessing combined with multi-threaded torch can deadlock, so force a single torch thread
torch.set_num_threads(1)

pipe = nlp.pipe(test_sents, n_process=os.cpu_count(), batch_size=25)

heads_spans = []
ents_spans = []

for d in pipe:
for d in tqdm(pipe):
for span in d._.relation_head:
heads_spans.append(span._.most_common_ancestor)
if span.ents:
Empty file.
7 changes: 5 additions & 2 deletions src/conspiracies/docprocessing/doc_utils.py
Expand Up @@ -5,6 +5,9 @@
from spacy.language import Language
from spacy.tokens import Doc

from conspiracies.docprocessing.relationextraction.data_classes import (
install_extensions,
)
from conspiracies.docprocessing.relationextraction.gptprompting import (
DocTriplets,
SpanTriplet,
@@ -30,7 +33,7 @@ def _doc_from_json(json: dict, nlp: Language) -> Doc:
for triplet_json in json["semantic_triplets"]
]
if not Doc.has_extension("relation_triplets"):
Doc.set_extension("relation_triplets", default=[], force=True)
install_extensions()
doc._.relation_triplets = DocTriplets(span_triplets=triplets, doc=doc)
return doc

@@ -65,7 +68,7 @@ def docs_from_jsonl(
A list of docs with the extension `doc._.relation_triplets` set.
"""
if not Doc.has_extension("relation_triplets"):
Doc.set_extension("relation_triplets", default=[], force=True)
install_extensions(force=True)
docs = []
with jsonlines.open(path, "r") as reader:
for json in reader:
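A hedged usage sketch of the loader above — docs_from_jsonl's full signature is collapsed in this diff, so the (path, nlp) arguments and the file name are assumptions:

import spacy

from conspiracies.docprocessing.doc_utils import docs_from_jsonl

nlp = spacy.blank("en")  # assumed: any Language object for reconstructing Docs
docs = docs_from_jsonl("gold_docs.jsonl", nlp)  # hypothetical file name
for doc in docs:
    # extensions are now installed by install_extensions() inside the loader
    print(doc._.relation_triplets)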
src/conspiracies/docprocessing/relationextraction/data_classes.py
@@ -687,7 +687,17 @@ def __eq__(self, other) -> bool:
if not isinstance(other, SpanTriplet):
return False

triplet_is_equal = all(s1 == s2 for s1, s2 in zip(self.triplet, other.triplet))
# Spans in different texts cannot be meaningfully compared, but spans
# over the same text may live in two different Doc objects
if self.doc.text != other.doc.text:
return False

# Equality is checked on start, end, and text so that two triplets in
# different Doc objects can be considered equal
triplet_is_equal = all(
s1.start == s2.start and s1.end == s2.end and s1.text == s2.text
for s1, s2 in zip(self.triplet, other.triplet)
)
return triplet_is_equal
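A minimal sketch of the new equality semantics; the keyword constructor is taken from elsewhere in this diff, while the blank pipeline and sentences are assumptions:

import spacy

from conspiracies.docprocessing.relationextraction.data_classes import SpanTriplet

nlp = spacy.blank("en")
doc_a = nlp("Anna likes apples.")
doc_b = nlp("Anna likes apples.")  # same text, different Doc object

t_a = SpanTriplet(subject=doc_a[0:1], predicate=doc_a[1:2], object=doc_a[2:3])
t_b = SpanTriplet(subject=doc_b[0:1], predicate=doc_b[1:2], object=doc_b[2:3])
assert t_a == t_b  # same start, end and text for each span -> equal

doc_c = nlp("Bob likes pears.")
t_c = SpanTriplet(subject=doc_c[0:1], predicate=doc_c[1:2], object=doc_c[2:3])
assert t_a != t_c  # different underlying texts -> never equal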


@@ -828,3 +838,75 @@ def __eq__(self, other: Any) -> bool:
if s != o:
return False
return True


def span_to_idx(span: Span) -> Tuple[int, int]:
return span.start, span.end


def idx_to_span(idx: Tuple[int, int], doc: Doc) -> Span:
start, end = idx
return doc[start:end]


def install_extensions(force=False) -> None:
"""Sets extensions on the SpaCy Doc class.
Relation triplets are stored internally as index tuples, but they
are created with getters and setters that map the index tuples to
and from SpaCy Span objects. Heads, relations and tails are
retrieved from the triplets. Confidence numbers are stored as is.
"""
extensions = [
"relation_triplet_idxs",
"relation_triplets",
"relation_head",
"relation_relation",
"relation_tail",
"relation_confidence",
]
if not force and all(Doc.has_extension(ext) for ext in extensions):
return # job's done!

Doc.set_extension("relation_triplet_idxs", default=[], force=force)
Doc.set_extension(
"relation_triplets",
setter=lambda doc, triplets: setattr(
doc._,
"relation_triplet_idxs",
[
tuple(span_to_idx(span) for span in span_triplet.triplet)
for span_triplet in triplets
],
),
getter=lambda doc: DocTriplets(
span_triplets=[
SpanTriplet.from_tuple(
(
idx_to_span(idx[0], doc),
idx_to_span(idx[1], doc),
idx_to_span(idx[2], doc),
),
)
for idx in doc._.relation_triplet_idxs
],
doc=doc,
),
force=force,
)
Doc.set_extension(
"relation_head",
getter=lambda doc: [t.subject for t in doc._.relation_triplets],
force=force,
)
Doc.set_extension(
"relation_relation",
getter=lambda doc: [t.predicate for t in doc._.relation_triplets],
force=force,
)
Doc.set_extension(
"relation_tail",
getter=lambda doc: [t.object for t in doc._.relation_triplets],
force=force,
)
Doc.set_extension("relation_confidence", default=None, force=force)
src/conspiracies/docprocessing/relationextraction/gptprompting/__init__.py
@@ -1,4 +1,8 @@
from .data_classes import SpanTriplet, StringTriplet, DocTriplets # noqa F401
from conspiracies.docprocessing.relationextraction.data_classes import (
SpanTriplet, # noqa F401
StringTriplet, # noqa F401
DocTriplets, # noqa F401
)
from .prompt_templates import ( # noqa F401
MarkdownPromptTemplate1,
MarkdownPromptTemplate2,
@@ -10,7 +10,11 @@
from spacy.training.example import Example

from conspiracies.registry import registry
from .data_classes import DocTriplets, SpanTriplet
from conspiracies.docprocessing.relationextraction.data_classes import (
DocTriplets,
SpanTriplet,
install_extensions,
)
from .prompt_apis import create_openai_chatgpt_prompt_api # noqa: F401


@@ -115,7 +119,7 @@ def __init__(
self.split_doc_fn = None

if not Doc.has_extension("relation_triplets") or force:
Doc.set_extension("relation_triplets", default=None, force=force)
install_extensions(force=force)

def combine_docs(self, docs: List[Doc]) -> Doc:
"""Combine a list of docs into a single doc."""
Expand Up @@ -8,7 +8,10 @@
from confection import registry
from spacy.tokens import Doc

from .data_classes import SpanTriplet, StringTriplet
from conspiracies.docprocessing.relationextraction.data_classes import (
SpanTriplet,
StringTriplet,
)


class PromptTemplate:
@@ -10,11 +10,14 @@

from .knowledge_triplets import KnowledgeTriplets
from .multi2oie_utils import (
install_extensions,
match_extraction_spans_to_wp,
wp2tokid,
wp_span_to_token,
)
from conspiracies.docprocessing.relationextraction.data_classes import (
install_extensions,
DocTriplets,
)


class SpacyRelationExtractor(TrainablePipe):
@@ -93,10 +96,8 @@ def set_annotations(self, doc: Iterable[Doc], predictions: Dict) -> None:
for indices, values in zip(filtered_indices, predictions[key])
]

# Output empty lists if empty doc or no extractions above threshold
# return early if no extractions are above the threshold
if not predictions["extraction"]:
setattr(doc._, "relation_confidence", []) # type: ignore
setattr(doc._, "relation_triplets", []) # type: ignore
return

# concatenate wordpieces and concatenate extraction span. Handle new extraction
@@ -121,7 +122,8 @@ def set_annotations(self, doc: Iterable[Doc], predictions: Dict) -> None:
# Set doc level attributes
merged_confidence = [j for i in predictions["confidence"] for j in i]
setattr(doc._, "relation_confidence", merged_confidence) # type: ignore
setattr(doc._, "relation_triplets", aligned_extractions) # type: ignore
triplets = DocTriplets(span_triplets=aligned_extractions, doc=doc)
setattr(doc._, "relation_triplets", triplets) # type: ignore

def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
"""Apply the pipe to a stream of documents.
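With the component now producing DocTriplets directly, an end-to-end sketch mirroring the config from paper/src/ents_heads_extraction.py (model and threshold are copied from that script; the printed shapes follow the extension definitions above):

import spacy

nlp = spacy.load("en_core_web_lg")
config = {"confidence_threshold": 2.7, "model_args": {"batch_size": 10}}
nlp.add_pipe("relation_extractor", config=config)

doc = nlp("Mette Frederiksen is the Danish politician.")
print(doc._.relation_triplets)    # a DocTriplets object
print(doc._.relation_confidence)  # list of confidences, or None if nothing extracted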
@@ -1,10 +1,12 @@
from typing import Dict, List, Tuple, Union

from spacy.tokens import Doc, Span
from spacy.tokens import Doc
from thinc.types import Ragged
from transformers import BertTokenizer
from functools import cache

from conspiracies.docprocessing.relationextraction.data_classes import SpanTriplet


#### Wordpiece <-> spaCy alignment functions
def wp2tokid(align: Ragged) -> Dict[int, int]:
@@ -50,7 +52,7 @@ def wp_span_to_token(
relation_span: List[List[int]],
wp_tokenid_mapping: Dict,
doc: Doc,
) -> List[Tuple[Span, Span, Span]]:
) -> List[SpanTriplet]:
"""Converts the wp span for each relation to spans.
Assumes that relations are contiguous
@@ -73,7 +75,7 @@ def wp_span_to_token(
relation = token_span_to_spacy_span(relation, doc)
tail = token_span_to_spacy_span(tail, doc)

relations.append((head, relation, tail))
relations.append(SpanTriplet(subject=head, predicate=relation, object=tail))
return relations


@@ -107,51 +109,6 @@ def match_extraction_spans_to_wp(
return matched_extractions


def span_to_idx(span: Span) -> Tuple[int, int]:
return span.start, span.end


def idx_to_span(idx: Tuple[int, int], doc: Doc) -> Span:
start, end = idx
return doc[start:end]


def install_extensions() -> None:
"""Sets extensions on the SpaCy Doc class.
Relation triplets are stored internally as index tuples, but they
are created with getters and setters that map the index tuples to
and from SpaCy Span objects. Heads, relations and tails are
retrieved from the triplets. Confidence numbers are stored as is.
"""
Doc.set_extension("relation_triplet_idxs", default=None)
Doc.set_extension(
"relation_triplets",
setter=lambda doc, triplets: setattr(
doc._,
"relation_triplet_idxs",
[tuple(span_to_idx(span) for span in triplet) for triplet in triplets],
),
getter=lambda doc: [
tuple(idx_to_span(idx, doc) for idx in triplet_idx)
for triplet_idx in doc._.relation_triplet_idxs
],
)
Doc.set_extension(
"relation_head",
getter=lambda doc: [t[0] for t in doc._.relation_triplets],
)
Doc.set_extension(
"relation_relation",
getter=lambda doc: [t[1] for t in doc._.relation_triplets],
)
Doc.set_extension(
"relation_tail",
getter=lambda doc: [t[2] for t in doc._.relation_triplets],
)
Doc.set_extension("relation_confidence", default=None)


@cache
def get_cached_tokenizer(model_name):
return BertTokenizer.from_pretrained(model_name)
@@ -205,7 +205,10 @@ def get_single_predicate_idxs(pred_tags):
elif tag.item() == pred_tag2idx["P-I"]:
cur_pred[b_idx + j] = pred_tag2idx["P-I"]
cur_sent_preds.append(cur_pred)
total_pred_tags.append(np.vstack(cur_sent_preds))
if cur_sent_preds:
total_pred_tags.append(np.vstack(cur_sent_preds))
else:
total_pred_tags.append(np.empty(0))
return [torch.from_numpy(pred_tags) for pred_tags in total_pred_tags]
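The new guard matters because np.vstack rejects an empty sequence; a minimal repro of the failure mode being fixed:

import numpy as np

try:
    np.vstack([])  # what happened when a sentence produced no predicate tags
except ValueError as err:
    print(err)  # need at least one array to concatenate

# The guarded version appends np.empty(0) instead, keeping one entry per sentence.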


8 changes: 6 additions & 2 deletions tests/test_data/prompt_data.py
@@ -1,6 +1,10 @@
from typing import List

import spacy

from conspiracies.docprocessing.relationextraction.data_classes import (
install_extensions,
)
from conspiracies.docprocessing.relationextraction.gptprompting import (
DocTriplets,
SpanTriplet,
@@ -161,7 +165,7 @@ def load_gold_triplets() -> List[Doc]:
span_triplets = [triplet for triplet in span_triplets_ if triplet is not None]

if not Doc.has_extension("relation_triplets"):
Doc.set_extension("relation_triplets", default=[], force=True)
install_extensions()
doc._.relation_triplets = DocTriplets(span_triplets=span_triplets, doc=doc)

# copy them to test with multiple examples.
@@ -212,7 +216,7 @@ def load_examples() -> List[Doc]:
nlp.add_pipe("sentencizer")

if not Doc.has_extension("relation_triplets"):
Doc.set_extension("relation_triplets", default=[], force=True)
install_extensions(force=True)
examples: List[Doc] = []
for example, triplet_list in [
(example_tweet_1, example_triplets_1),