Include only input-relevant named entities when producing output #382

Merged
38 changes: 16 additions & 22 deletions src/ontogpt/engines/knowledge_engine.py
@@ -1,4 +1,5 @@
 """Main Knowledge Extractor class."""
+
 import logging
 import re
 from abc import ABC
@@ -41,11 +42,13 @@
 # if it's not installed
 try:
     from ontogpt.clients import OpenAIClient, GPT4AllClient
+
     CLIENT_TYPES = Union[OpenAIClient, GPT4AllClient]
 except ImportError:
     logger.warning("GPT4All client not available. GPT4All support will be disabled.")
     from ontogpt.clients import OpenAIClient
-    CLIENT_TYPES = OpenAIClient  # type: ignore
+
+    CLIENT_TYPES = OpenAIClient  # type: ignore
 
 # annotation metamodel
 ANNOTATION_KEY_PROMPT = "prompt"
@@ -54,24 +57,6 @@
 ANNOTATION_KEY_RECURSE = "ner.recurse"
 ANNOTATION_KEY_EXAMPLES = "prompt.examples"
 
-# TODO: introspect
-# TODO: move this to its own module
-DATAMODELS = [
-    "biological_process.BiologicalProcess",
-    "biotic_interaction.BioticInteraction",
-    "cell_type.CellTypeDocument",
-    "ctd.ChemicalToDiseaseDocument",
-    "diagnostic_procedure.DiagnosticProceduretoPhenotypeAssociation",
-    "drug.DrugMechanism",
-    "environmental_sample.Study",
-    "gocam.GoCamAnnotations",
-    "mendelian_disease.MendelianDisease",
-    "phenotype.Trait",
-    "reaction.Reaction",
-    "recipe.Recipe",
-    "treatment.DiseaseTreatmentSummary",
-]
-
 
 def chunk_text(text: str, window_size=3) -> Iterator[str]:
     """Chunk text into windows of sentences."""
@@ -152,7 +137,11 @@ class KnowledgeEngine(ABC):
     """Min proportion of overlap in characters between text and grounding. TODO: use tokenization"""
 
     named_entities: List[NamedEntity] = field(default_factory=list)
-    """Cache of all named entities"""
+    """Cache of all named entities. This is not written to output directly as each input
+    has its own corresponding named entities."""
+
+    extracted_named_entities: List[NamedEntity] = field(default_factory=list)
+    """Temporary cache of named entities, to be cleared between extractions."""
 
     auto_prefix: str = ""
     """If set then non-normalized named entities will be mapped to this prefix"""
@@ -349,14 +338,18 @@ def normalize_named_entity(self, text: str, range: ElementName) -> str:
             logger.info(f"Grounding {text} to {obj_id}; next step is to normalize")
             for normalized_id in self.normalize_identifier(obj_id, cls):
                 if not any(e for e in self.named_entities if e.id == normalized_id):
-                    self.named_entities.append(NamedEntity(id=normalized_id, label=text))
+                    ne = NamedEntity(id=normalized_id, label=text)
+                    self.named_entities.append(ne)
+                    self.extracted_named_entities.append(ne)
                 logger.info(f"Normalized {text} with {obj_id} to {normalized_id}")
                 return normalized_id
         logger.info(f"Could not ground and normalize {text} to {cls.name}")
         if self.auto_prefix:
             obj_id = f"{self.auto_prefix}:{quote(text)}"
             if not any(e for e in self.named_entities if e.id == obj_id):
-                self.named_entities.append(NamedEntity(id=obj_id, label=text))
+                ne = NamedEntity(id=obj_id, label=text)
+                self.named_entities.append(ne)
+                self.extracted_named_entities.append(ne)
         else:
             obj_id = text
         if ANNOTATION_KEY_RECURSE in cls.annotations:
@@ -370,6 +363,7 @@
                 except ValueError as e:
                     logger.error(f"No id for {obj} {e}")
                 self.named_entities.append(obj)
+                self.extracted_named_entities.append(obj)
         return obj_id
 
     def is_valid_identifier(self, input_id: str, cls: ClassDefinition) -> bool:
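Taken together, the hunks above split entity tracking into two collections: named_entities stays as the long-lived cache that accumulates across every input the engine processes, while the new extracted_named_entities is a per-input buffer that each grounding path appends to alongside the cache. The following is a minimal, self-contained sketch of that pattern; the ToyEngine class and its ground/extract helpers are illustrative stand-ins, not ontogpt's actual engine API.

# Dual-cache pattern sketch (illustrative names, not ontogpt's real classes).
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class NamedEntity:
    id: str
    label: str


@dataclass
class ToyEngine:
    named_entities: List[NamedEntity] = field(default_factory=list)
    """Long-lived cache, accumulated across every input the engine sees."""

    extracted_named_entities: List[NamedEntity] = field(default_factory=list)
    """Per-input buffer, cleared at the start of each extraction."""

    def ground(self, entity_id: str, label: str) -> None:
        # Mirror normalize_named_entity above: record the entity in both collections.
        ne = NamedEntity(id=entity_id, label=label)
        self.named_entities.append(ne)
        self.extracted_named_entities.append(ne)

    def extract(self, mentions: Dict[str, str]) -> List[NamedEntity]:
        """Ground an {id: label} mapping, standing in for extraction from text."""
        self.extracted_named_entities = []  # clear the buffer for this input
        for entity_id, label in mentions.items():
            self.ground(entity_id, label)
        return self.extracted_named_entities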
7 changes: 6 additions & 1 deletion src/ontogpt/engines/spires_engine.py
@@ -69,6 +69,8 @@ def extract_from_text(
         :param object: optional stub object
         :return:
         """
+        self.extracted_named_entities = []  # Clear the named entity buffer
+
         if self.sentences_per_window:
             chunks = chunk_text(text, self.sentences_per_window)
             extracted_object = None
@@ -95,12 +97,15 @@
             extracted_object = self.parse_completion_payload(
                 raw_text, cls, object=object  # type: ignore
             )
+
         return ExtractionResult(
             input_text=text,
             raw_completion_output=raw_text,
             prompt=self.last_prompt,
             extracted_object=extracted_object,
-            named_entities=self.named_entities,
+            named_entities=self.extracted_named_entities,
+            # Note these are the named entities from the last extraction,
+            # not the full list of all named entities across all extractions
         )
 
     def _extract_from_text_to_dict(self, text: str, cls: ClassDefinition = None) -> RESPONSE_DICT:
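The spires_engine.py change is the user-visible half of the PR: extract_from_text clears the per-input buffer before processing and returns it as ExtractionResult.named_entities, so the result for one input no longer carries entities grounded from earlier inputs, while the engine-level named_entities cache keeps accumulating as before. Continuing the illustrative ToyEngine sketch above (again, stand-in names rather than ontogpt's real classes), the behavioral difference looks like this:

# Each extraction now reports only its own entities; the engine cache still grows.
engine = ToyEngine()

first = engine.extract({"CHEBI:15377": "water"})
second = engine.extract({"MONDO:0005015": "diabetes mellitus"})

# Per-input results no longer include entities from earlier inputs...
assert [ne.id for ne in first] == ["CHEBI:15377"]
assert [ne.id for ne in second] == ["MONDO:0005015"]

# ...while the cumulative cache still holds everything, as it did before this PR.
assert [ne.id for ne in engine.named_entities] == ["CHEBI:15377", "MONDO:0005015"]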