Skip to content

Commit

Permalink
tidy and document cli functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
leej3 committed Aug 26, 2024
1 parent ee3544d commit d7d2250
Show file tree
Hide file tree
Showing 12 changed files with 118 additions and 77 deletions.
40 changes: 5 additions & 35 deletions osm/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import argparse
import base64
import hashlib
import logging
import os
import shlex
Expand All @@ -12,8 +10,6 @@
import pandas as pd
import requests

from osm._version import __version__

DEFAULT_OUTPUT_DIR = "./osm_output"
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -62,35 +58,6 @@ def get_compute_context_id():
return hash(f"{os.environ.get('HOSTNAME')}_{os.environ.get('USERNAME')}")


def _upload_data(args, file_in, xml, metrics, components):
osm_api = os.environ.get("OSM_API", "http://localhost:80")

payload = {
"osm_version": __version__,
"user_comment": args.comment,
"work": {
"user_defined_id": args.uid,
"filename": args.file.name,
"file": base64.b64encode(file_in).decode("utf-8"),
"content_hash": hashlib.sha256(file_in).hexdigest(),
},
"client": {
"compute_context_id": get_compute_context_id(),
"email": args.email,
},
"metrics": metrics,
"components": components,
}
# Send POST request to OSM API
response = requests.put(f"{osm_api}/upload", json=payload)

# Check response status code
if response.status_code == 200:
print("Invocation data uploaded successfully")
else:
print(f"Failed to upload invocation data: \n {response.text}")


def wait_for_containers():
while True:
try:
Expand Down Expand Up @@ -128,8 +95,10 @@ def _setup(args):
raise FileExistsError(xml_path)
elif args.filepath.name.endswith(".xml"):
logger.warning(
"""The input file is an xml file. Skipping the pdf to text
conversion and so ignoring requested parsers."""
"""
The input file is an xml file. Skipping the pdf to text conversion
and so ignoring requested parsers.
"""
)
args.parser = ["no-op"]
metrics_path = _get_metrics_dir() / f"{args.uid}.json"
Expand All @@ -141,6 +110,7 @@ def _setup(args):
logger.info("Waiting for containers to be ready...")
print("Waiting for containers to be ready...")
wait_for_containers()
print("Containers ready!")
return xml_path, metrics_path


Expand Down
1 change: 1 addition & 0 deletions osm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def parse_args():
parser.add_argument(
"--comment",
required=False,
default="",
help="Comment to provide more information about the provided publication.",
)
parser.add_argument(
Expand Down
27 changes: 16 additions & 11 deletions osm/pipeline/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,32 @@ def __init__(self, version: str = "0.0.1"):
def run(self, data: Any, **kwargs) -> Any:
pass

def _get_model_fields(self) -> dict[str, Any]:
return {
"name": self.name,
"version": self.version,
}
def _get_orm_fields(self) -> dict[str, Any]:
fields = {}
for fieldname in self.orm_model_class.model_fields.keys():
if hasattr(self, fieldname):
fields[fieldname] = getattr(self, fieldname)

return fields

@property
def name(self) -> str:
return self.__class__.__name__

@property
def orm_model_class(self) -> type:
return schemas.Component

@property
def orm_model(self) -> schemas.Component:
if self._orm_model is None:
self._orm_model = schemas.Component(
**self._get_model_fields(),
)
self._orm_model = self.orm_model_class(
**self._get_orm_fields(),
)
return self._orm_model

def model_dump(self) -> dict[str, Any]:
def model_dump(self, *args, **kwargs) -> dict[str, Any]:
"""Return a dict of the components model."""
return self.orm_model.model_dump()
return self.orm_model.model_dump(*args, **kwargs)


class Savers:
Expand Down
1 change: 0 additions & 1 deletion osm/pipeline/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def run(self, data: str, parser: str = None) -> dict:
metrics[k] = None
return metrics
else:
breakpoint()
logger.error(f"Error: {response.text}")
response.raise_for_status()

Expand Down
6 changes: 6 additions & 0 deletions osm/pipeline/parsers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import requests

from osm.schemas.custom_fields import LongBytes

from .core import Component

SCIENCEBEAM_URL = "http://localhost:8070/api/convert"


class NoopParser(Component):
"""Used if the input is xml and so needs no parsing."""

def run(self, data: bytes) -> str:
self.sample = LongBytes(data)
return data.decode("utf-8")


class ScienceBeamParser(Component):
def run(self, data: bytes) -> str:
self.sample = LongBytes(data)
headers = {"Accept": "application/tei+xml", "Content-Type": "application/pdf"}
response = requests.post(SCIENCEBEAM_URL, data=data, headers=headers)
if response.status_code == 200:
Expand Down
72 changes: 54 additions & 18 deletions osm/pipeline/savers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,75 @@
import os
from pathlib import Path

import dill
import requests
from pydantic import ValidationError

from osm import schemas
from osm._utils import get_compute_context_id
from osm._version import __version__
from osm.schemas import Invocation

from .core import Component

logger = logging.getLogger(__name__)


class FileSaver(Component):
"""Basic saver that writes data to a file."""

def run(self, data: str, path: Path):
"""Write data to a file.
Args:
            data (str): Text content to write to the file.
path (Path): A file path.
"""
path.write_text(data)
logger.info(f"Data saved to {path}")


class JSONSaver(Component):
"""Saver that writes JSON data to a file."""

def run(self, data: dict, path: Path):
"""Write output metrics to a JSON file for the user.
Args:
data (dict): Metrics conformant to a schema.
path (Path): An output path for the metrics.
"""
path.write_text(json.dumps(data))
logger.info(f"Metrics saved to {path}")
print(f"Metrics saved to {path}")


class OSMSaver(Component):
"""A class to gather savers to run a pipeline."""

def __init__(self, comment, email, user_defined_id, filename):
"""Upload data to the OSM API.
Args:
            comment (str): A comment from the user to inform downstream analysis.
email (str): For users to be contactable for future data curation etc.
user_defined_id (str): pmid, pmcid, doi, or other unique identifier.
filename (str): Name of the file being processed.
"""
super().__init__()
self.compute_context_id = get_compute_context_id()
self.comment = comment
self.email = email
self.user_defined_id = user_defined_id
self.filename = filename

def run(self, file_in: bytes, metrics: dict, components: list):
def run(self, file_in: bytes, metrics: dict, components: list[schemas.Component]):
"""Save the extracted metrics to the OSM API.
Args:
file_in (bytes): Component input.
metrics (dict): Schema conformant metrics.
components (list[schemas.Component]): parsers, extractors, and savers that constitute the pipeline.
"""
osm_api = os.environ.get("OSM_API", "https://osm.pythonaisolutions.com/api")
# Build the payload
payload = {
Expand All @@ -48,26 +82,25 @@ def run(self, file_in: bytes, metrics: dict, components: list):
"work": {
"user_defined_id": self.user_defined_id,
"filename": self.filename,
"file": base64.b64encode(file_in).decode("utf-8"),
"content_hash": hashlib.sha256(file_in).hexdigest(),
},
"client": {
"compute_context_id": self.compute_context_id,
"email": self.email,
},
"metrics": metrics,
"components": [comp.model_dump() for comp in components],
"metrics": schemas.RtransparentMetrics(**metrics),
"components": [comp.orm_model for comp in components],
}
try:
# Validate the payload
validated_data = Invocation(**payload)
validated_data = schemas.Invocation(**payload)
# If validation passes, send POST request to OSM API. ID is not
# serializable but can be excluded and created by the DB. All types
# should be serializable. If they're not then a they should be encoded
# as a string or something like that: base64.b64encode(bytes).decode("utf-8")
# should be serializable. If they're not then they should be encoded
# as a string or something like that: base64.b64encode(bytes).decode("utf-8")
response = requests.put(
f"{osm_api}/upload/",
json=validated_data.model_dump(exclude=["id"]),
json=validated_data.model_dump(mode="json", exclude=["id"]),
)
if response.status_code == 200:
print("Invocation data uploaded successfully")
Expand All @@ -76,18 +109,21 @@ def run(self, file_in: bytes, metrics: dict, components: list):
f"Failed to upload invocation data: \n {response.text}"
)
except (ValidationError, ValueError) as e:
breakpoint()
try:
payload["upload_error"] = str(e)
# Quarantine the failed payload
response = requests.put(f"{osm_api}/quarantine", json=payload)
failure = schemas.Quarantine(
payload=base64.b64encode(dill.dumps(payload)).decode("utf-8"),
error_message=str(e),
).model_dump(mode="json", exclude=["id"])
response = requests.put(f"{osm_api}/quarantine/", json=failure)
response.raise_for_status()
raise e
except requests.exceptions.RequestException as qe:
requests.put(
f"{osm_api}/quarantine",
json={
"upload_error": str(e),
"recovery_error": str(qe),
},
f"{osm_api}/quarantine/",
json=schemas.Quarantine(
error_message=str(e),
recovery_message=str(qe),
).model_dump(mode="json", exclude=["id"]),
)
logger.warning(f"Validation failed: {e}")
raise e
1 change: 1 addition & 0 deletions osm/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .schemas import Client as Client
from .schemas import Component as Component
from .schemas import Invocation as Invocation
from .schemas import Quarantine as Quarantine
from .schemas import RtransparentMetrics as RtransparentMetrics
from .schemas import Work as Work
2 changes: 1 addition & 1 deletion osm/schemas/custom_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class LongStr(LongField[str]):


class LongBytes(LongField[bytes]):
"""A bytes type that displays '...' instead of the full content in logs or tracebacks."""
"""Wrap byte streams to avoid messy output."""

_inner_schema: ClassVar[CoreSchema] = core_schema.bytes_schema()
_error_kind: ClassVar[str] = "bytes_type"
Expand Down
21 changes: 15 additions & 6 deletions osm/schemas/metrics_schemas.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from odmantic import EmbeddedModel
from pydantic import field_validator
from pydantic import field_serializer, field_validator

from osm._utils import coerce_to_string

Expand All @@ -12,11 +12,6 @@
# all_indicators.csv from the rtransparent publication has both but has the following extra fields:
# code_text,com_code,com_data_availibility,com_file_formats,com_general_db,com_github_data,com_specific_db,com_suppl_code,com_supplemental_data,data_text,dataset,eigenfactor_score,field,is_art,is_code_pred,is_data_pred,is_relevant_code,is_relevant_data,jif,n_cite,score,year,
class RtransparentMetrics(EmbeddedModel):
model_config = {
"json_encoders": {
LongStr: lambda v: v.get_value(),
},
}
# Mandatory fields
is_open_code: Optional[bool]
is_open_data: Optional[bool]
Expand Down Expand Up @@ -197,3 +192,17 @@ class RtransparentMetrics(EmbeddedModel):
@field_validator("article")
def fix_string(cls, v):
return coerce_to_string(v)

@field_serializer(
"data_text",
"code_text",
"coi_text",
"fund_text",
"register_text",
"funding_text",
"open_code_statements",
"open_data_category",
"open_data_statements",
)
def serialize_longstr(self, value: Optional[LongStr]) -> Optional[str]:
return value.get_value() if value else None
15 changes: 11 additions & 4 deletions osm/schemas/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pandas as pd
from odmantic import EmbeddedModel, Field, Model
from pydantic import EmailStr, field_validator
from pydantic import EmailStr, field_serializer, field_validator

from osm._utils import coerce_to_string

Expand All @@ -15,9 +15,6 @@
class Component(EmbeddedModel):
model_config = {
"extra": "forbid",
"json_encoders": {
LongBytes: lambda v: base64.b64encode(v.get_value()).decode("utf-8"),
},
}
name: str
version: str
Expand All @@ -28,6 +25,10 @@ class Component(EmbeddedModel):
json_schema_extra={"exclude": True, "select": False, "write_only": True},
)

@field_serializer("sample")
def serialize_longbytes(self, value: Optional[LongBytes]) -> Optional[str]:
return base64.b64encode(value.get_value()).decode("utf-8") if value else None


class Client(EmbeddedModel):
model_config = {"extra": "forbid"}
Expand Down Expand Up @@ -81,3 +82,9 @@ class Invocation(Model):
created_at: datetime.datetime = Field(
default_factory=lambda: datetime.datetime.now(datetime.UTC)
)


class Quarantine(Model):
payload: str = ""
error_message: str
recovery_message: Optional[str] = None
Loading

0 comments on commit d7d2250

Please sign in to comment.