From 5eba967d7d98658a1d31f3a830a747a102701d88 Mon Sep 17 00:00:00 2001 From: leej3 Date: Wed, 28 Aug 2024 18:06:39 +0100 Subject: [PATCH] fixup --- .dockerignore | 1 + .gitignore | 1 + external_components/llm_extraction/app.py | 18 +++++++++++++----- osm/pipeline/core.py | 6 ++++-- osm/pipeline/extractors.py | 5 +++-- osm/pipeline/parsers.py | 4 ++-- osm/pipeline/savers.py | 1 + osm/schemas/metrics_schemas.py | 2 +- osm/schemas/schemas.py | 6 +++--- web/dashboard/Dockerfile | 2 +- 10 files changed, 30 insertions(+), 16 deletions(-) diff --git a/.dockerignore b/.dockerignore index 3fa8c86b..7209ca22 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1 +1,2 @@ .terraform +tempdata diff --git a/.gitignore b/.gitignore index e6f0a174..d921bad8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Ignore the virtual environment directory _version.py +dashboard_data node_modules venv/ *coverage* diff --git a/external_components/llm_extraction/app.py b/external_components/llm_extraction/app.py index 3ed6d643..5d59612b 100644 --- a/external_components/llm_extraction/app.py +++ b/external_components/llm_extraction/app.py @@ -5,9 +5,10 @@ from llama_index.core.llms import LLM, ChatMessage from llama_index.llms.openai import OpenAI from llama_index.program.openai import OpenAIPydanticProgram +from pydantic import ValidationError # from pydantic import BaseModel, Field -from osm.schemas.metrics_schemas import LLMExtractor +from osm.schemas.metrics_schemas import LLMExtractorMetrics LLM_MODELS = {"gpt-4o-2024-08-06": OpenAI(model="gpt-4o-2024-08-06")} @@ -38,7 +39,7 @@ def get_program(llm: LLM) -> OpenAIPydanticProgram: ) program = OpenAIPydanticProgram.from_defaults( - output_cls=LLMExtractor, + output_cls=LLMExtractorMetrics, llm=llm, prompt=prompt, verbose=True, @@ -46,7 +47,7 @@ def get_program(llm: LLM) -> OpenAIPydanticProgram: return program -def extract_with_llm(xml_content: bytes, llm: LLM) -> LLMExtractor: +def extract_with_llm(xml_content: bytes, llm: LLM) -> 
LLMExtractorMetrics: program = get_program(llm=llm) return program(xml_content=xml_content, llm_model=llm.model) @@ -58,7 +59,7 @@ def llm_metric_extraction( return extract_with_llm(xml_content, LLM_MODELS[llm_model]) -@app.post("/extract-metrics/", response_model=LLMExtractor) +@app.post("/extract-metrics/", response_model=LLMExtractorMetrics) async def extract_metrics( file: UploadFile = File(...), llm_model: str = Query("other") ): @@ -69,7 +70,14 @@ async def extract_metrics( """For now the XML content must be provided. Check the output of the parsing stage.""" ) - metrics = llm_metric_extraction(xml_content, llm_model) + for ii in range(5): + try: + metrics = llm_metric_extraction(xml_content, llm_model) + break + except ValidationError as e: + # retry if it is just a validation error (the LLM can try harder next time) + print("Validation error:", e) + logger.info(metrics) return metrics diff --git a/osm/pipeline/core.py b/osm/pipeline/core.py index 16e38ba9..952efb15 100644 --- a/osm/pipeline/core.py +++ b/osm/pipeline/core.py @@ -99,7 +99,7 @@ def __init__( self.xml_path = xml_path self.metrics_path = metrics_path - def run(self, user_managed_compose: bool = False): + def run(self, user_managed_compose: bool = False, llm_model: str = None): for parser in self.parsers: parsed_data = parser.run( self.file_data, user_managed_compose=user_managed_compose ) if isinstance(parsed_data, bytes): self.savers.save_file(parsed_data, self.xml_path) for extractor in self.extractors: - extracted_metrics = extractor.run(parsed_data, parser=parser.name) + extracted_metrics = extractor.run( parsed_data, parser=parser.name, llm_model=llm_model ) self.savers.save_osm( data=self.file_data, metrics=extracted_metrics, diff --git a/osm/pipeline/extractors.py b/osm/pipeline/extractors.py index 109ca198..686f4b49 100644 --- a/osm/pipeline/extractors.py +++ b/osm/pipeline/extractors.py @@ -11,7 +11,7 @@ class 
RTransparentExtractor(Component): - def _run(self, data: bytes, parser: str = None) -> dict: + def _run(self, data: bytes, parser: str = None, **kwargs) -> dict: self.sample = LongBytes(data) # Prepare the file to be sent as a part of form data @@ -39,7 +39,8 @@ def _run(self, data: bytes, parser: str = None) -> dict: class LLMExtractor(Component): - def _run(self, data: bytes, llm_model: str = None) -> dict: + def _run(self, data: bytes, llm_model: str = None, **kwargs) -> dict: + llm_model = llm_model or kwargs.get("llm_model", "gpt-4o-2024-08-06") self.sample = LongBytes(data) # Prepare the file to be sent as a part of form data diff --git a/osm/pipeline/parsers.py b/osm/pipeline/parsers.py index 133bb4bc..8dc59d60 100644 --- a/osm/pipeline/parsers.py +++ b/osm/pipeline/parsers.py @@ -13,7 +13,7 @@ class NoopParser(Component): """Used if the input is xml and so needs no parsing.""" - def _run(self, data: bytes) -> bytes: + def _run(self, data: bytes, **kwargs) -> bytes: return data @@ -30,7 +30,7 @@ class PMCParser(NoopParser): class ScienceBeamParser(Component): - def _run(self, data: bytes, user_managed_compose=False) -> str: + def _run(self, data: bytes, user_managed_compose=False, **kwargs) -> str: self.sample = LongBytes(data) headers = {"Accept": "application/tei+xml", "Content-Type": "application/pdf"} files = {"file": ("input.pdf", io.BytesIO(data), "application/pdf")} diff --git a/osm/schemas/metrics_schemas.py b/osm/schemas/metrics_schemas.py index 8f7d483d..b1d44ba8 100644 --- a/osm/schemas/metrics_schemas.py +++ b/osm/schemas/metrics_schemas.py @@ -208,7 +208,7 @@ def serialize_longstr(self, value: Optional[LongStr]) -> Optional[str]: return value.get_value() if value else None -class LLMExtractor(EmbeddedModel): +class LLMExtractorMetrics(EmbeddedModel): """ Model for extracting information from scientific publications. These metrics are a summary of the publications adherence to transparent or open diff --git a/osm/schemas/schemas.py b/osm/schemas/schemas.py index 43efbbd1..36923b6b 100644 --- a/osm/schemas/schemas.py +++ b/osm/schemas/schemas.py @@ -1,6 +1,6 @@ import base64 import datetime -from typing import Optional +from typing import Optional, Union import pandas as pd from odmantic import EmbeddedModel, Field, Model @@ -9,7 +9,7 @@ from osm._utils import coerce_to_string from .custom_fields import LongBytes -from .metrics_schemas import RtransparentMetrics +from .metrics_schemas import LLMExtractorMetrics, RtransparentMetrics class Component(EmbeddedModel): @@ -71,7 +71,7 @@ class Invocation(Model): """ model_config = {"extra": "forbid"} - metrics: RtransparentMetrics + metrics: Union[RtransparentMetrics, LLMExtractorMetrics] components: Optional[list[Component]] = [] work: Work client: Client diff --git a/web/dashboard/Dockerfile b/web/dashboard/Dockerfile index 7daf4b25..a6517386 100644 --- a/web/dashboard/Dockerfile +++ b/web/dashboard/Dockerfile @@ -8,7 +8,7 @@ RUN pip install panel pymongo odmantic pandas pydantic[email] pyarrow RUN mkdir -p /opt/data ENV LOCAL_DATA_PATH=/opt/data/matches.parquet -COPY ./tempdata/matches.parquet /opt/data/matches.parquet +COPY ./dashboard_data/matches.parquet /opt/data/matches.parquet RUN mkdir -p /opt/osm COPY pyproject.toml /opt/osm