Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
leej3 committed Aug 28, 2024
1 parent d98faab commit 5eba967
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 16 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.terraform
tempdata
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Ignore the virtual environment directory
_version.py
dashboard_data
node_modules
venv/
*coverage*
Expand Down
18 changes: 13 additions & 5 deletions external_components/llm_extraction/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from llama_index.core.llms import LLM, ChatMessage
from llama_index.llms.openai import OpenAI
from llama_index.program.openai import OpenAIPydanticProgram
from pydantic import ValidationError

# from pydantic import BaseModel, Field
from osm.schemas.metrics_schemas import LLMExtractor
from osm.schemas.metrics_schemas import LLMExtractorMetrics

LLM_MODELS = {"gpt-4o-2024-08-06": OpenAI(model="gpt-4o-2024-08-06")}

Expand Down Expand Up @@ -38,15 +39,15 @@ def get_program(llm: LLM) -> OpenAIPydanticProgram:
)

program = OpenAIPydanticProgram.from_defaults(
output_cls=LLMExtractor,
output_cls=LLMExtractorMetrics,
llm=llm,
prompt=prompt,
verbose=True,
)
return program


def extract_with_llm(xml_content: bytes, llm: LLM) -> LLMExtractor:
def extract_with_llm(xml_content: bytes, llm: LLM) -> LLMExtractorMetrics:
    """Run the pydantic extraction program against raw XML and return validated metrics.

    Builds the OpenAIPydanticProgram for *llm* via get_program and invokes it with
    the XML payload; the LLM's model name is forwarded so it can be recorded in
    the resulting LLMExtractorMetrics.
    """
    program = get_program(llm=llm)
    return program(xml_content=xml_content, llm_model=llm.model)

Expand All @@ -58,7 +59,7 @@ def llm_metric_extraction(
return extract_with_llm(xml_content, LLM_MODELS[llm_model])


@app.post("/extract-metrics/", response_model=LLMExtractor)
@app.post("/extract-metrics/", response_model=LLMExtractorMetrics)
async def extract_metrics(
file: UploadFile = File(...), llm_model: str = Query("other")
):
Expand All @@ -69,7 +70,14 @@ async def extract_metrics(
"""For now the XML content must be provided. Check the output of
the parsing stage."""
)
# Retry on ValidationError only: LLM output is stochastic, so a response that
# fails schema validation may validate on a re-prompt. The original loop had
# the control flow inverted — it `break`ed on the exception (so it never
# retried) and did NOT break on success (so a successful first call was
# repeated 5 times), and `metrics` could be unbound afterwards.
for attempt in range(5):
    try:
        metrics = llm_metric_extraction(xml_content, llm_model)
        break  # success: stop retrying
    except ValidationError as e:
        # Retry if it is just a validation error (the LLM can try harder
        # next time); surface the error if the final attempt also fails.
        print("Validation error:", e)
        if attempt == 4:
            raise

logger.info(metrics)
return metrics

Expand Down
6 changes: 4 additions & 2 deletions osm/pipeline/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,17 @@ def __init__(
self.xml_path = xml_path
self.metrics_path = metrics_path

def run(self, user_managed_compose: bool = False):
def run(self, user_managed_compose: bool = False, llm_model: str = None):
for parser in self.parsers:
parsed_data = parser.run(
self.file_data, user_managed_compose=user_managed_compose
)
if isinstance(parsed_data, bytes):
self.savers.save_file(parsed_data, self.xml_path)
for extractor in self.extractors:
extracted_metrics = extractor.run(parsed_data, parser=parser.name)
extracted_metrics = extractor.run(
parsed_data, parser=parser.name, llm_model=llm_model
)
self.savers.save_osm(
data=self.file_data,
metrics=extracted_metrics,
Expand Down
5 changes: 3 additions & 2 deletions osm/pipeline/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class RTransparentExtractor(Component):
def _run(self, data: bytes, parser: str = None) -> dict:
def _run(self, data: bytes, parser: str = None, **kwargs) -> dict:
self.sample = LongBytes(data)

# Prepare the file to be sent as a part of form data
Expand Down Expand Up @@ -39,7 +39,8 @@ def _run(self, data: bytes, parser: str = None) -> dict:


class LLMExtractor(Component):
def _run(self, data: bytes, llm_model: str = None) -> dict:
def _run(self, data: bytes, llm_model: str = None, **kwargs) -> dict:
    # Dead lookup removed: `llm_model` is a named parameter, so any caller
    # passing llm_model (positionally or by keyword) binds the parameter and
    # the key can never appear in **kwargs — kwargs.get("llm_model", default)
    # only ever returned the fallback. Apply the default directly.
    llm_model = llm_model or "gpt-4o-2024-08-06"
self.sample = LongBytes(data)

# Prepare the file to be sent as a part of form data
Expand Down
4 changes: 2 additions & 2 deletions osm/pipeline/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class NoopParser(Component):
"""Used if the input is xml and so needs no parsing."""

def _run(self, data: bytes) -> bytes:
def _run(self, data: bytes, **kwargs) -> bytes:
return data


Expand All @@ -30,7 +30,7 @@ class PMCParser(NoopParser):


class ScienceBeamParser(Component):
def _run(self, data: bytes, user_managed_compose=False) -> str:
def _run(self, data: bytes, user_managed_compose=False, **kwargs) -> str:
self.sample = LongBytes(data)
headers = {"Accept": "application/tei+xml", "Content-Type": "application/pdf"}
files = {"file": ("input.pdf", io.BytesIO(data), "application/pdf")}
Expand Down
1 change: 1 addition & 0 deletions osm/pipeline/savers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _run(self, data: bytes, metrics: dict, components: list[schemas.Component]):
)
raise e
try:
breakpoint()
# Validate the payload
validated_data = schemas.Invocation(**payload)
# If validation passes, send POST request to OSM API. ID is not
Expand Down
2 changes: 1 addition & 1 deletion osm/schemas/metrics_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def serialize_longstr(self, value: Optional[LongStr]) -> Optional[str]:
return value.get_value() if value else None


class LLMExtractor(EmbeddedModel):
class LLMExtractorMetrics(EmbeddedModel):
"""
Model for extracting information from scientific publications. These metrics
are a summary of the publications adherence to transparent or open
Expand Down
6 changes: 3 additions & 3 deletions osm/schemas/schemas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64
import datetime
from typing import Optional
from typing import Optional, Union

import pandas as pd
from odmantic import EmbeddedModel, Field, Model
Expand All @@ -9,7 +9,7 @@
from osm._utils import coerce_to_string

from .custom_fields import LongBytes
from .metrics_schemas import RtransparentMetrics
from .metrics_schemas import LLMExtractorMetrics, RtransparentMetrics


class Component(EmbeddedModel):
Expand Down Expand Up @@ -71,7 +71,7 @@ class Invocation(Model):
"""

model_config = {"extra": "forbid"}
metrics: RtransparentMetrics
metrics: Union[RtransparentMetrics, LLMExtractorMetrics]
components: Optional[list[Component]] = []
work: Work
client: Client
Expand Down
2 changes: 1 addition & 1 deletion web/dashboard/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ RUN pip install panel pymongo odmantic pandas pydantic[email] pyarrow

RUN mkdir -p /opt/data
ENV LOCAL_DATA_PATH=/opt/data/matches.parquet
COPY ./tempdata/matches.parquet /opt/data/matches.parquet
COPY ./dashboard_data/matches.parquet /opt/data/matches.parquet

RUN mkdir -p /opt/osm
COPY pyproject.toml /opt/osm
Expand Down

0 comments on commit 5eba967

Please sign in to comment.