Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Expand and improve schema usage and tests #56

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 22 additions & 26 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,52 +1,48 @@
name: Tests

on: push
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: pre-commit/[email protected]
tox-tests:
runs-on: ubuntu-latest
- uses: actions/setup-python@v4
with:
python-version: '3.11'
- uses: pre-commit/[email protected]

pytest:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.9]
python-version: [3.11]

steps:

- name: Checkout repository
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Cache pip
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[dev]
pip install uv
uv pip install --system -r <(uv pip compile --all-extras pyproject.toml)

- name: Start ScienceBeam Docker container
- name: Start Docker stack
run: |
docker run -d --rm -p 8070 elifesciences/sciencebeam-parser
docker compose -f compose.yaml -f compose.development.override.yaml up -d --build

- name: Run tests
run: |
tox
run: pytest tests

- name: Test packaging
run: |
tox -e .package
- name: Stop Docker stack
if: always()
run: docker compose -f compose.yaml -f compose.development.override.yaml down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ osm_output
.public_dns
tempdata
.aider*
build-cache
24 changes: 22 additions & 2 deletions compose.development.override.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,25 @@ services:
volumes:
- ./external_components/rtransparent:/app

llm_extraction:
container_name: llm_extraction
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY:-NOKEY}
- OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-NOKEY}
build:
context: .
dockerfile: ./external_components/llm_extraction/Dockerfile
volumes:
- ./external_components/llm_extraction:/app
depends_on:
- base
develop:
watch:
- action: sync+restart
path: ./osm
target: /opt/osm/osm
ignore:
- "**/__pycache__"

############ Development images ############

Expand All @@ -22,6 +41,7 @@ services:
- MONGO_INITDB_DATABASE=osm

base:
container_name: base
command: ["echo", "base image"]
image: nimhdsst/osm_base:latest
build:
Expand Down Expand Up @@ -50,7 +70,7 @@ services:
path: ./osm
target: /opt/osm/osm
ignore:
- __pycache__
- __pycache__/*

web_api:
container_name: web_api
Expand All @@ -75,4 +95,4 @@ services:
path: ./osm
target: /opt/osm/osm
ignore:
- __pycache__
- __pycache__/*
5 changes: 5 additions & 0 deletions compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ services:
image: nimhdsst/rtransparent:staging
ports:
- "8071:8071"
llm_extraction:
container_name: llm_extraction
image: nimhdsst/llm_extraction:staging
ports:
- "8072:8072"
91 changes: 91 additions & 0 deletions external_components/llm_extraction/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import logging

from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from llama_index.core import ChatPromptTemplate
from llama_index.core.llms import LLM, ChatMessage
from llama_index.llms.openai import OpenAI
from llama_index.program.openai import OpenAIPydanticProgram
from pydantic import ValidationError

# from pydantic import BaseModel, Field
from osm.schemas.metrics_schemas import LLMExtractorMetrics

# Registry of model names accepted by the /extract-metrics/ endpoint, mapped
# to pre-configured LLM client instances. Keys are the exact strings callers
# pass as the `llm_model` query parameter.
# NOTE(review): the commented-out OpenRouter entry is an alternative provider
# for the same model — confirm credentials/routing before re-enabling.
LLM_MODELS = {
    # "gpt-4o-2024-08-06": OpenRouter(
    #     api_key=os.environ["OPENROUTER_API_KEY"],
    #     model="openai/chatgpt-4o-latest",
    # )
    "gpt-4o-2024-08-06": OpenAI(model="gpt-4o-2024-08-06")
}


# Module-level logger and the FastAPI application served by this component.
logger = logging.getLogger(__name__)
app = FastAPI()


def get_program(llm: LLM) -> OpenAIPydanticProgram:
    """Build the pydantic-extraction program for *llm*.

    The program pairs a two-message chat prompt (system persona plus a user
    message carrying the publication XML) with the LLMExtractorMetrics schema,
    so calling it returns a validated metrics object.
    """
    system_message = ChatMessage(
        role="system",
        content=(
            "You are an expert at extracting information from scientific publications with a keen eye for details that when combined together allows you to summarize aspects of the publication"
        ),
    )
    user_message = ChatMessage(
        role="user",
        content=(
            "The llm model is {llm_model}. The publication in xml follows below:\n"
            "------\n"
            "{xml_content}\n"
            "------"
        ),
    )
    chat_prompt = ChatPromptTemplate(message_templates=[system_message, user_message])

    return OpenAIPydanticProgram.from_defaults(
        output_cls=LLMExtractorMetrics,
        llm=llm,
        prompt=chat_prompt,
        verbose=True,
    )


def extract_with_llm(xml_content: bytes, llm: LLM) -> LLMExtractorMetrics:
    """Run the extraction program for *llm* over the raw publication XML."""
    extraction_program = get_program(llm=llm)
    metrics = extraction_program(xml_content=xml_content, llm_model=llm.model)
    return metrics


def llm_metric_extraction(
    xml_content: bytes,
    llm_model: str,
):
    """Extract metrics from *xml_content* using the named model.

    Raises KeyError when *llm_model* is not a key of LLM_MODELS.
    """
    llm = LLM_MODELS[llm_model]
    return extract_with_llm(xml_content, llm)


@app.post("/extract-metrics/", response_model=LLMExtractorMetrics)
async def extract_metrics(
file: UploadFile = File(...), llm_model: str = Query("other")
):
try:
xml_content = await file.read()
if not xml_content:
raise NotImplementedError(
"""For now the XML content must be provided. Check the output of
the parsing stage."""
)
for ii in range(5):
try:
metrics = llm_metric_extraction(xml_content, llm_model)
except ValidationError as e:
# retry if it is just a validation error (the LLM can try harder next time)
print("Validation error:", e)
break

logger.info(metrics)
return metrics

except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
2 changes: 1 addition & 1 deletion osm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def generate_version_file():
import pkg_resources

if os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION_FOR_OSM"):
version = os.environ["SETUPTOOLS_SCM_PRETEND_VERSION_FOR_OSM"]
version = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION_FOR_OSM")
else:
version = pkg_resources.get_distribution("osm").version
version_file_content = f"version = '{version}'\n"
Expand Down
13 changes: 11 additions & 2 deletions osm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from osm._utils import DEFAULT_OUTPUT_DIR, _existing_file, _setup, compose_down
from osm.pipeline.core import Pipeline, Savers
from osm.pipeline.extractors import RTransparentExtractor
from osm.pipeline.extractors import LLMExtractor, RTransparentExtractor
from osm.pipeline.parsers import NoopParser, PMCParser, ScienceBeamParser
from osm.pipeline.savers import FileSaver, JSONSaver, OSMSaver

Expand All @@ -13,6 +13,7 @@
}
EXTRACTORS = {
"rtransparent": RTransparentExtractor,
"llm_extractor": LLMExtractor,
}


Expand Down Expand Up @@ -51,6 +52,11 @@ def parse_args():
nargs="+",
help="Select the tool for extracting the output metrics. Default is 'rtransparent'.",
)
parser.add_argument(
"--llm_model",
default="gpt-4o-2024-08-06",
help="Specify the model to use for LLM extraction.",
)
parser.add_argument(
"--comment",
required=False,
Expand Down Expand Up @@ -93,7 +99,10 @@ def main():
),
),
)
pipeline.run(user_managed_compose=args.user_managed_compose)
pipeline.run(
user_managed_compose=args.user_managed_compose,
llm_model=args.llm_model,
)
finally:
if not args.user_managed_compose:
compose_down()
Expand Down
29 changes: 29 additions & 0 deletions osm/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os

from motor.motor_asyncio import AsyncIOMotorClient

# MongoDB connection string; falls back to a local instance when the
# MONGODB_URI environment variable is not set.
DB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017")


async def db_init(db_uri: str = DB_URI):
    """Return a handle to the ``osm`` database at *db_uri*.

    Callers (e.g. the pipeline) treat an exception here as "database
    unreachable". NOTE(review): Motor clients connect lazily, so constructing
    one may not actually probe the server — confirm this check does what the
    pipeline expects.
    """
    # Honor the db_uri argument; the original ignored it and always used the
    # module-level DB_URI via get_mongo_client(). Dead `pass` and commented-out
    # yield/close scaffolding removed.
    client = AsyncIOMotorClient(db_uri)
    return client.get_database("osm")


def get_mongo_client(db_uri: str = DB_URI) -> AsyncIOMotorClient:
    """Create a new Motor (async MongoDB) client for *db_uri*.

    The optional *db_uri* parameter (default: module-level DB_URI) makes the
    signature consistent with db_init and lets tests target a different
    server; existing zero-argument callers are unaffected.
    """
    return AsyncIOMotorClient(db_uri)


def get_mongo_db(mongo_client: AsyncIOMotorClient | None = None):
    """Return the ``osm`` database, creating a client when none is supplied."""
    client = get_mongo_client() if mongo_client is None else mongo_client
    return client.get_database("osm")


def get_mongo_session(mongo_client: AsyncIOMotorClient):
    """Start and return a session on *mongo_client* (e.g. for transactions)."""
    session = mongo_client.start_session()
    return session
15 changes: 12 additions & 3 deletions osm/pipeline/core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Optional

from osm import schemas
from osm.db import db_init


class Component(ABC):
Expand Down Expand Up @@ -99,19 +101,26 @@ def __init__(
self.xml_path = xml_path
self.metrics_path = metrics_path

def run(self, user_managed_compose: bool = False):
def run(self, user_managed_compose: bool = False, llm_model: str = None):
try:
asyncio.run(db_init())
except Exception as e:
print(e)
raise EnvironmentError("Could not connect to the OSM database.")
for parser in self.parsers:
parsed_data = parser.run(
self.file_data, user_managed_compose=user_managed_compose
)
if isinstance(parsed_data, bytes):
self.savers.save_file(parsed_data, self.xml_path)
for extractor in self.extractors:
extracted_metrics = extractor.run(parsed_data, parser=parser.name)
extracted_metrics = extractor.run(
parsed_data, parser=parser.name, llm_model=llm_model
)
self.savers.save_osm(
data=self.file_data,
metrics=extracted_metrics,
components=[*self.parsers, *self.extractors, *self.savers],
components=[parser, extractor, *self.savers],
)
self.savers.save_json(extracted_metrics, self.metrics_path)

Expand Down
Loading
Loading