
Commit
temp
leej3 committed Aug 29, 2024
1 parent 6e3e125 commit f420eff
Showing 9 changed files with 216 additions and 11 deletions.
85 changes: 85 additions & 0 deletions external_components/llm_extraction/app.py
@@ -0,0 +1,85 @@
import logging

from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from llama_index.core import ChatPromptTemplate
from llama_index.core.llms import LLM, ChatMessage
from llama_index.llms.openai import OpenAI
from llama_index.program.openai import OpenAIPydanticProgram
from pydantic import ValidationError

# from pydantic import BaseModel, Field
from osm.schemas.metrics_schemas import LLMExtractorMetrics

LLM_MODELS = {"gpt-4o-2024-08-06": OpenAI(model="gpt-4o-2024-08-06")}


logger = logging.getLogger(__name__)
app = FastAPI()


def get_program(llm: LLM) -> OpenAIPydanticProgram:
prompt = ChatPromptTemplate(
message_templates=[
ChatMessage(
role="system",
content=(
"You are an expert at extracting information from scientific publications with a keen eye for details that when combined together allows you to summarize aspects of the publication"
),
),
ChatMessage(
role="user",
content=(
"The llm model is {llm_model}. The publication in xml follows below:\n"
"------\n"
"{xml_content}\n"
"------"
),
),
]
)

program = OpenAIPydanticProgram.from_defaults(
output_cls=LLMExtractorMetrics,
llm=llm,
prompt=prompt,
verbose=True,
)
return program


def extract_with_llm(xml_content: bytes, llm: LLM) -> LLMExtractorMetrics:
program = get_program(llm=llm)
return program(xml_content=xml_content, llm_model=llm.model)


def llm_metric_extraction(
xml_content: bytes,
llm_model: str,
):
return extract_with_llm(xml_content, LLM_MODELS[llm_model])


@app.post("/extract-metrics/", response_model=LLMExtractorMetrics)
async def extract_metrics(
    file: UploadFile = File(...), llm_model: str = Query("gpt-4o-2024-08-06")
):
try:
xml_content = await file.read()
if not xml_content:
raise NotImplementedError(
"""For now the XML content must be provided. Check the output of
the parsing stage."""
)
        metrics = None
        for attempt in range(5):
            try:
                metrics = llm_metric_extraction(xml_content, llm_model)
                break
            except ValidationError as e:
                # Retry on validation errors only: the LLM may produce
                # schema-conformant output on the next attempt.
                logger.warning("Validation error (attempt %d): %s", attempt + 1, e)
        if metrics is None:
            raise ValueError("LLM extraction failed validation after 5 attempts")

        logger.info(metrics)
        return metrics

except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
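A minimal sketch of exercising this endpoint locally (assumptions: the service listens on port 8072, the port LLMExtractor targets below, and "paper.xml" is a hypothetical parsed publication):

import requests

with open("paper.xml", "rb") as f:
    response = requests.post(
        "http://localhost:8072/extract-metrics/",
        files={"file": ("paper.xml", f, "application/xml")},
        params={"llm_model": "gpt-4o-2024-08-06"},
    )
response.raise_for_status()
print(response.json()["is_open_code"])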
8 changes: 8 additions & 0 deletions osm/_utils.py
@@ -2,6 +2,7 @@
import datetime
import logging
import os
import re
import time
import types
from pathlib import Path
@@ -152,3 +153,10 @@ def flatten_dict(d):
else:
items.append((k, v))
return dict(items)


def camel_to_snake(name: str) -> str:
# Replace capital letters with underscore + lowercase letter
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
# Handle cases where a lowercase is followed by an uppercase letter
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
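A few illustrative conversions (not part of the commit) showing what the two substitutions handle:

assert camel_to_snake("camelCase") == "camel_case"
assert camel_to_snake("HTTPResponse") == "http_response"
assert camel_to_snake("LLMExtractorMetrics") == "llm_extractor_metrics"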
8 changes: 7 additions & 1 deletion osm/cli.py
@@ -2,7 +2,7 @@

from osm._utils import DEFAULT_OUTPUT_DIR, _existing_file, _setup, compose_down
from osm.pipeline.core import Pipeline, Savers
from osm.pipeline.extractors import RTransparentExtractor
from osm.pipeline.extractors import LLMExtractor, RTransparentExtractor
from osm.pipeline.parsers import NoopParser, PMCParser, ScienceBeamParser
from osm.pipeline.savers import FileSaver, JSONSaver, OSMSaver

@@ -13,6 +13,7 @@
}
EXTRACTORS = {
"rtransparent": RTransparentExtractor,
"llm_extractor": LLMExtractor,
}


@@ -51,6 +52,11 @@ def parse_args():
nargs="+",
help="Select the tool for extracting the output metrics. Default is 'rtransparent'.",
)
parser.add_argument(
"--llm_model",
default="gpt-4o-2024-08-06",
help="Specify the model to use for LLM extraction.",
)
parser.add_argument(
"--comment",
required=False,
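How the flag reaches the extractor is outside this diff; presumably the CLI wires it through Pipeline.run, along the lines of (hypothetical wiring; only parse_args is shown above):

args = parse_args()
# pipeline construction elided in this diff
pipeline.run(llm_model=args.llm_model)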
6 changes: 4 additions & 2 deletions osm/pipeline/core.py
@@ -99,15 +99,17 @@ def __init__(
self.xml_path = xml_path
self.metrics_path = metrics_path

def run(self, user_managed_compose: bool = False):
def run(self, user_managed_compose: bool = False, llm_model: str = None):
for parser in self.parsers:
parsed_data = parser.run(
self.file_data, user_managed_compose=user_managed_compose
)
if isinstance(parsed_data, bytes):
self.savers.save_file(parsed_data, self.xml_path)
for extractor in self.extractors:
extracted_metrics = extractor.run(parsed_data, parser=parser.name)
extracted_metrics = extractor.run(
parsed_data, parser=parser.name, llm_model=llm_model
)
self.savers.save_osm(
data=self.file_data,
metrics=extracted_metrics,
25 changes: 24 additions & 1 deletion osm/pipeline/extractors.py
@@ -11,7 +11,7 @@


class RTransparentExtractor(Component):
def _run(self, data: bytes, parser: str = None) -> dict:
def _run(self, data: bytes, parser: str = None, **kwargs) -> dict:
self.sample = LongBytes(data)

# Prepare the file to be sent as a part of form data
@@ -38,6 +38,29 @@ def _run(self, data: bytes, parser: str = None) -> dict:
response.raise_for_status()


class LLMExtractor(Component):
    def _run(self, data: bytes, llm_model: str = None, **kwargs) -> dict:
        # Fall back to the default model when the pipeline does not specify one.
        # (kwargs.get("llm_model") could never find the key here, since the
        # named parameter consumes it.)
        llm_model = llm_model or "gpt-4o-2024-08-06"
self.sample = LongBytes(data)

# Prepare the file to be sent as a part of form data
files = {"file": ("input.xml", io.BytesIO(data), "application/xml")}

# Send the request with the file
response = requests.post(
"http://localhost:8072/extract-metrics/",
files=files,
params={"llm_model": llm_model},
)

if response.status_code == 200:
metrics = response.json()
return metrics
else:
logger.error(f"Error: {response.text}")
response.raise_for_status()


# import psutil
# # Adjust the logging level for rpy2
# rpy2_logger = logging.getLogger("rpy2")
4 changes: 2 additions & 2 deletions osm/pipeline/parsers.py
@@ -13,7 +13,7 @@
class NoopParser(Component):
"""Used if the input is xml and so needs no parsing."""

def _run(self, data: bytes) -> bytes:
def _run(self, data: bytes, **kwargs) -> bytes:
return data


@@ -30,7 +30,7 @@ class PMCParser(NoopParser):


class ScienceBeamParser(Component):
def _run(self, data: bytes, user_managed_compose=False) -> str:
def _run(self, data: bytes, user_managed_compose=False, **kwargs) -> str:
self.sample = LongBytes(data)
headers = {"Accept": "application/tei+xml", "Content-Type": "application/pdf"}
files = {"file": ("input.pdf", io.BytesIO(data), "application/pdf")}
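The signature changes above all follow one pattern: Pipeline.run broadcasts run-wide options (such as llm_model) to every component, and each component accepts **kwargs so it can ignore options meant for others. A standalone sketch of the idea (hypothetical names, not the repo's actual base class):

class Component:
    def run(self, data: bytes, **kwargs):
        return self._run(data, **kwargs)

class EchoParser(Component):
    def _run(self, data: bytes, **kwargs) -> bytes:
        # llm_model and other broadcast options arrive in kwargs and are ignored
        return data

assert EchoParser().run(b"<xml/>", llm_model="gpt-4o-2024-08-06") == b"<xml/>"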
1 change: 1 addition & 0 deletions osm/pipeline/savers.py
@@ -109,6 +109,7 @@ def _run(self, data: bytes, metrics: dict, components: list[schemas.Component]):
raise e
try:
# Validate the payload
breakpoint()
validated_data = schemas.Invocation(**payload)
# If validation passes, send POST request to OSM API. ID is not
# serializable but can be excluded and created by the DB. All types
82 changes: 81 additions & 1 deletion osm/schemas/metrics_schemas.py
@@ -1,6 +1,6 @@
from typing import Optional

from odmantic import EmbeddedModel
from odmantic import EmbeddedModel, Field
from pydantic import field_serializer, field_validator

from osm._utils import coerce_to_string
@@ -218,3 +218,83 @@ class ManualAnnotationNIMHDSST(EmbeddedModel):
manual_data_statements: Optional[str]
Notes: Optional[LongStr]
PMID_raw: Optional[int]

class LLMExtractorMetrics(EmbeddedModel):
"""
    Model for extracting information from scientific publications. These metrics
    summarize the publication's adherence to transparent or open scientific
    practices.
    Many unavailable identifiers (PMID, PMCID, etc.) can be found using PubMed: https://pubmed.ncbi.nlm.nih.gov/advanced/
"""

    llm_model: str = Field(
        description="Exact version of the LLM model used to generate the data (not in the publication itself but known by the model), e.g. GPT_4o_2024_08_06"
    )
    year: int = Field(
        description="Best attempt at extracting the year of publication; use the integer 9999 if it cannot be determined",
    )
journal: str = Field(description="The journal in which the paper was published")
article_type: list[str] = Field(
description="The type of article e.g. research article, review, erratum, meta-analysis etc.",
)
country: list[str] = Field(
description="The countries of the affiliations of the authors",
)
institute: list[str] = Field(
description="The institutes of the affiliations of the authors",
)
doi: str = Field(description="The DOI of the paper")
pmid: int = Field(
description="The PMID of the paper, use the integer 0 if one cannot be found"
)
pmcid: int = Field(
description="The PMCID of the paper, use the integer 0 if one cannot be found"
)
title: str = Field(description="The title of the paper")
authors: list[str] = Field(description="The authors of the paper")
publisher: str = Field(description="The publisher of the paper")
is_open_code: bool = Field(
description="Whether there is evidence that the code used for analysis in the paper has been shared online",
)
code_sharing_statement: list[str] = Field(
description="The statement in the paper that indicates whether the code used for analysis has been shared online",
)
is_open_data: bool = Field(
description="Whether there is evidence that the data used for analysis in the paper has been shared online",
)
data_sharing_statement: list[str] = Field(
description="The statement in the paper that indicates whether the data used for analysis has been shared online",
)
data_repository_url: str = Field(
description="The URL of the repository where the data can be found"
)
dataset_unique_identifier: list[str] = Field(
description="Any unique identifiers the dataset may have"
)
    code_repository_url: str = Field(
        description="The URL of the repository where the code can be found"
    )
has_coi_statement: bool = Field(
description="Whether there is a conflict of interest statement in the paper",
)
coi_statement: list[str] = Field(
description="The conflict of interest statement in the paper"
)
funder: list[str] = Field(
description="The funders of the research, may contain multiple funders",
)
has_funding_statement: bool = Field(
description="Whether there is a funding statement in the paper"
)
funding_statement: list[str] = Field(
description="The funding statement in the paper"
)
has_registration_statement: bool = Field(
description="Whether there is a registration statement in the paper",
)
registration_statement: list[str] = Field(
description="The registration statement in the paper"
)
reasoning_steps: list[str] = Field(
description="The reasoning steps used to extract the information from the paper",
)
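Because the extraction service retries on ValidationError, it is worth seeing how pydantic reports a mismatch between the LLM's JSON and a schema like this one; a standalone toy sketch (illustrative only, assuming pydantic v2):

from pydantic import BaseModel, ValidationError

class Toy(BaseModel):
    year: int

try:
    Toy(year="twenty twenty-four")  # the kind of malformed value an LLM can emit
except ValidationError as e:
    # Each error names the offending field ("loc") and the failure kind ("type").
    print(e.errors()[0]["loc"], e.errors()[0]["type"])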
8 changes: 4 additions & 4 deletions osm/schemas/schemas.py
@@ -9,8 +9,7 @@
from osm._utils import coerce_to_string

from .custom_fields import LongBytes
from .metrics_schemas import ManualAnnotationNIMHDSST, RtransparentMetrics

from .metrics_schemas import (
    LLMExtractorMetrics,
    ManualAnnotationNIMHDSST,
    RtransparentMetrics,
)

class Component(EmbeddedModel):
model_config = {
@@ -71,8 +70,9 @@ class Invocation(Model):
"""

model_config = {"extra": "forbid"}
metrics: Union[RtransparentMetrics | ManualAnnotationNIMHDSST]
metrics_group: str
manual_annotation_nimhdsst: Optional[ManualAnnotationNIMHDSST] = None
llm_extractor_metrics: Optional[LLMExtractorMetrics] = None
rtransparent_metrics: Optional[RtransparentMetrics] = None
components: Optional[list[Component]] = []
work: Work
client: Client
