Skip to content

Commit

Permalink
Merge pull request #43 from nimh-dsst/cli-improvements
Browse files Browse the repository at this point in the history
CLI improvements
  • Loading branch information
leej3 authored Aug 26, 2024
2 parents ec8234b + d7d2250 commit 9faa99a
Show file tree
Hide file tree
Showing 15 changed files with 200 additions and 105 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ osm -f path/to/pdf-or-xml2 -u uuid2 --user-managed-compose

# Contributing

N.B. On Apple silicon you must use emulation, i.e. `export DOCKER_DEFAULT_PLATFORM=linux/amd64`

If you wish to contribute to this project you can set up a development environment with the following:

```
Expand Down
40 changes: 5 additions & 35 deletions osm/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import argparse
import base64
import hashlib
import logging
import os
import shlex
Expand All @@ -12,8 +10,6 @@
import pandas as pd
import requests

from osm._version import __version__

DEFAULT_OUTPUT_DIR = "./osm_output"
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,35 +58,6 @@ def get_compute_context_id():
return hash(f"{os.environ.get('HOSTNAME')}_{os.environ.get('USERNAME')}")


def _upload_data(args, file_in, xml, metrics, components):
    """Upload the invocation payload (file, metrics, components) to the OSM API.

    Args:
        args: Parsed CLI namespace; reads ``comment``, ``uid``, ``file``, ``email``.
        file_in (bytes): Raw contents of the input file.
        xml: Parsed XML for the work (not used in this function).
        metrics: Extracted metrics to include in the payload.
        components: Component descriptions to include in the payload.
    """
    # API endpoint is configurable via the environment; defaults to localhost.
    osm_api = os.environ.get("OSM_API", "http://localhost:80")

    payload = {
        "osm_version": __version__,
        "user_comment": args.comment,
        "work": {
            "user_defined_id": args.uid,
            "filename": args.file.name,
            # Raw bytes must be base64-encoded to be JSON-serializable.
            "file": base64.b64encode(file_in).decode("utf-8"),
            # Hash lets the server deduplicate/verify the uploaded file.
            "content_hash": hashlib.sha256(file_in).hexdigest(),
        },
        "client": {
            "compute_context_id": get_compute_context_id(),
            "email": args.email,
        },
        "metrics": metrics,
        "components": components,
    }
    # Upload the payload to the OSM API (PUT to the /upload endpoint).
    response = requests.put(f"{osm_api}/upload", json=payload)

    # Report success/failure to the user on stdout.
    if response.status_code == 200:
        print("Invocation data uploaded successfully")
    else:
        print(f"Failed to upload invocation data: \n {response.text}")


def wait_for_containers():
while True:
try:
Expand Down Expand Up @@ -128,8 +95,10 @@ def _setup(args):
raise FileExistsError(xml_path)
elif args.filepath.name.endswith(".xml"):
logger.warning(
"""The input file is an xml file. Skipping the pdf to text
conversion and so ignoring requested parsers."""
"""
The input file is an xml file. Skipping the pdf to text conversion
and so ignoring requested parsers.
"""
)
args.parser = ["no-op"]
metrics_path = _get_metrics_dir() / f"{args.uid}.json"
Expand All @@ -141,6 +110,7 @@ def _setup(args):
logger.info("Waiting for containers to be ready...")
print("Waiting for containers to be ready...")
wait_for_containers()
print("Containers ready!")
return xml_path, metrics_path


Expand Down
1 change: 1 addition & 0 deletions osm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def parse_args():
parser.add_argument(
"--comment",
required=False,
default="",
help="Comment to provide more information about the provided publication.",
)
parser.add_argument(
Expand Down
27 changes: 16 additions & 11 deletions osm/pipeline/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,32 @@ def __init__(self, version: str = "0.0.1"):
def run(self, data: Any, **kwargs) -> Any:
    """Execute this pipeline step on `data`; concrete components override this."""
    pass

def _get_orm_fields(self) -> dict[str, Any]:
    """Collect values for the ORM model's fields from this component.

    Only fields that this component actually defines as attributes are
    included, so fields it does not set are simply omitted from the
    constructed ORM model.

    Returns:
        dict[str, Any]: Mapping of field name to value for every ORM
        model field present on this component.
    """
    # Dict comprehension replaces the manual accumulate-in-a-loop form;
    # iterating model_fields directly is equivalent to .keys().
    return {
        fieldname: getattr(self, fieldname)
        for fieldname in self.orm_model_class.model_fields
        if hasattr(self, fieldname)
    }

@property
def name(self) -> str:
    """Human-readable component name: the concrete class's name."""
    return type(self).__name__

@property
def orm_model_class(self) -> type:
    """Schema class used to persist this component; subclasses may override."""
    return schemas.Component

@property
def orm_model(self) -> schemas.Component:
    """Lazily build and cache the ORM model for this component.

    The model is constructed once from the fields gathered by
    `_get_orm_fields` and reused on subsequent accesses.
    """
    if self._orm_model is None:
        self._orm_model = self.orm_model_class(**self._get_orm_fields())
    return self._orm_model

def model_dump(self, *args, **kwargs) -> dict[str, Any]:
    """Return a dict representation of the component's ORM model.

    All positional and keyword arguments are forwarded to the underlying
    pydantic ``model_dump`` (e.g. ``mode``, ``exclude``), so callers get
    the full serialization API rather than a fixed, argument-less dump.
    """
    return self.orm_model.model_dump(*args, **kwargs)


class Savers:
Expand Down
1 change: 0 additions & 1 deletion osm/pipeline/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def run(self, data: str, parser: str = None) -> dict:
metrics[k] = None
return metrics
else:
breakpoint()
logger.error(f"Error: {response.text}")
response.raise_for_status()

Expand Down
6 changes: 6 additions & 0 deletions osm/pipeline/parsers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import requests

from osm.schemas.custom_fields import LongBytes

from .core import Component

SCIENCEBEAM_URL = "http://localhost:8070/api/convert"


class NoopParser(Component):
    """Pass-through parser for inputs that are already XML.

    No pdf-to-text conversion is performed; the raw bytes are recorded
    as the component's sample and returned decoded as UTF-8 text.
    """

    def run(self, data: bytes) -> str:
        # Keep a copy of the raw input bytes for provenance.
        self.sample = LongBytes(data)
        text = data.decode("utf-8")
        return text


class ScienceBeamParser(Component):
def run(self, data: bytes) -> str:
self.sample = LongBytes(data)
headers = {"Accept": "application/tei+xml", "Content-Type": "application/pdf"}
response = requests.post(SCIENCEBEAM_URL, data=data, headers=headers)
if response.status_code == 200:
Expand Down
72 changes: 54 additions & 18 deletions osm/pipeline/savers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,75 @@
import os
from pathlib import Path

import dill
import requests
from pydantic import ValidationError

from osm import schemas
from osm._utils import get_compute_context_id
from osm._version import __version__
from osm.schemas import Invocation

from .core import Component

logger = logging.getLogger(__name__)


class FileSaver(Component):
    """Saver that persists text output to a local file."""

    def run(self, data: str, path: Path):
        """Write `data` to `path` and log the destination.

        Args:
            data (str): Text content to persist.
            path (Path): Destination file path.
        """
        destination = Path(path)
        destination.write_text(data)
        logger.info(f"Data saved to {destination}")


class JSONSaver(Component):
    """Saver that serializes metrics to a JSON file."""

    def run(self, data: dict, path: Path):
        """Serialize `data` as JSON into `path`, then report the location.

        Args:
            data (dict): Metrics conformant to a schema.
            path (Path): An output path for the metrics.
        """
        serialized = json.dumps(data)
        path.write_text(serialized)
        # Log for the record and print so interactive users see the path.
        logger.info(f"Metrics saved to {path}")
        print(f"Metrics saved to {path}")


class OSMSaver(Component):
"""A class to gather savers to run a pipeline."""

def __init__(self, comment, email, user_defined_id, filename):
    """Store the metadata later attached to uploads sent to the OSM API.

    Args:
        comment (str): A comment from the user to inform downstream analysis.
        email (str): For users to be contactable for future data curation etc.
        user_defined_id (str): pmid, pmcid, doi, or other unique identifier.
        filename (str): Name of the file being processed.
    """
    super().__init__()
    # Identifies the host/user combination that produced this upload.
    self.compute_context_id = get_compute_context_id()
    self.comment = comment
    self.email = email
    self.user_defined_id = user_defined_id
    self.filename = filename

def run(self, file_in: bytes, metrics: dict, components: list):
def run(self, file_in: bytes, metrics: dict, components: list[schemas.Component]):
"""Save the extracted metrics to the OSM API.
Args:
file_in (bytes): Component input.
metrics (dict): Schema conformant metrics.
components (list[schemas.Component]): parsers, extractors, and savers that constitute the pipeline.
"""
osm_api = os.environ.get("OSM_API", "https://osm.pythonaisolutions.com/api")
# Build the payload
payload = {
Expand All @@ -48,26 +82,25 @@ def run(self, file_in: bytes, metrics: dict, components: list):
"work": {
"user_defined_id": self.user_defined_id,
"filename": self.filename,
"file": base64.b64encode(file_in).decode("utf-8"),
"content_hash": hashlib.sha256(file_in).hexdigest(),
},
"client": {
"compute_context_id": self.compute_context_id,
"email": self.email,
},
"metrics": metrics,
"components": [comp.model_dump() for comp in components],
"metrics": schemas.RtransparentMetrics(**metrics),
"components": [comp.orm_model for comp in components],
}
try:
# Validate the payload
validated_data = Invocation(**payload)
validated_data = schemas.Invocation(**payload)
# If validation passes, send POST request to OSM API. ID is not
# serializable but can be excluded and created by the DB. All types
# should be serializable. If they're not then a they should be encoded
# as a string or something like that: base64.b64encode(bytes).decode("utf-8")
# should be serializable. If they're not then they should be encoded
# as a string or something like that: base64.b64encode(bytes).decode("utf-8")
response = requests.put(
f"{osm_api}/upload/",
json=validated_data.model_dump(exclude=["id"]),
json=validated_data.model_dump(mode="json", exclude=["id"]),
)
if response.status_code == 200:
print("Invocation data uploaded successfully")
Expand All @@ -76,18 +109,21 @@ def run(self, file_in: bytes, metrics: dict, components: list):
f"Failed to upload invocation data: \n {response.text}"
)
except (ValidationError, ValueError) as e:
breakpoint()
try:
payload["upload_error"] = str(e)
# Quarantine the failed payload
response = requests.put(f"{osm_api}/quarantine", json=payload)
failure = schemas.Quarantine(
payload=base64.b64encode(dill.dumps(payload)).decode("utf-8"),
error_message=str(e),
).model_dump(mode="json", exclude=["id"])
response = requests.put(f"{osm_api}/quarantine/", json=failure)
response.raise_for_status()
raise e
except requests.exceptions.RequestException as qe:
requests.put(
f"{osm_api}/quarantine",
json={
"upload_error": str(e),
"recovery_error": str(qe),
},
f"{osm_api}/quarantine/",
json=schemas.Quarantine(
error_message=str(e),
recovery_message=str(qe),
).model_dump(mode="json", exclude=["id"]),
)
logger.warning(f"Validation failed: {e}")
raise e
1 change: 1 addition & 0 deletions osm/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .schemas import Client as Client
from .schemas import Component as Component
from .schemas import Invocation as Invocation
from .schemas import Quarantine as Quarantine
from .schemas import RtransparentMetrics as RtransparentMetrics
from .schemas import Work as Work
2 changes: 1 addition & 1 deletion osm/schemas/custom_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class LongStr(LongField[str]):


class LongBytes(LongField[bytes]):
"""A bytes type that displays '...' instead of the full content in logs or tracebacks."""
"""Wrap byte streams to avoid messy output."""

_inner_schema: ClassVar[CoreSchema] = core_schema.bytes_schema()
_error_kind: ClassVar[str] = "bytes_type"
Expand Down
21 changes: 15 additions & 6 deletions osm/schemas/metrics_schemas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from odmantic import EmbeddedModel
from pydantic import field_validator
from pydantic import field_serializer, field_validator

from osm._utils import coerce_to_string

Expand All @@ -12,11 +12,6 @@
# all_indicators.csv from the rtransparent publication has both but has the following extra fields:
# code_text,com_code,com_data_availibility,com_file_formats,com_general_db,com_github_data,com_specific_db,com_suppl_code,com_supplemental_data,data_text,dataset,eigenfactor_score,field,is_art,is_code_pred,is_data_pred,is_relevant_code,is_relevant_data,jif,n_cite,score,year,
class RtransparentMetrics(EmbeddedModel):
model_config = {
"json_encoders": {
LongStr: lambda v: v.get_value(),
},
}
# Mandatory fields
is_open_code: Optional[bool]
is_open_data: Optional[bool]
Expand Down Expand Up @@ -197,3 +192,17 @@ class RtransparentMetrics(EmbeddedModel):
@field_validator("article")
def fix_string(cls, v):
    # Normalize the `article` value to a string via the shared helper.
    return coerce_to_string(v)

@field_serializer(
    "data_text",
    "code_text",
    "coi_text",
    "fund_text",
    "register_text",
    "funding_text",
    "open_code_statements",
    "open_data_category",
    "open_data_statements",
)
def serialize_longstr(self, value: Optional[LongStr]) -> Optional[str]:
    """Unwrap LongStr fields to plain strings for serialization.

    Uses an explicit ``is not None`` check (rather than truthiness) so a
    present-but-empty LongStr serializes to "" instead of being silently
    collapsed to None.
    """
    return value.get_value() if value is not None else None
Loading

0 comments on commit 9faa99a

Please sign in to comment.