diff --git a/README.md b/README.md index 2a0eb564..57976355 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,8 @@ osm -f path/to/pdf-or-xml2 -u uuid2 --user-managed-compose # Contributing +N.B. On Apple silicon you must use emulation i.e. `export DOCKER_DEFAULT_PLATFORM=linux/amd64` + If you wish to contribute to this project you can set up a development environment with the following: ``` diff --git a/osm/_utils.py b/osm/_utils.py index cad55d08..aad09d5a 100644 --- a/osm/_utils.py +++ b/osm/_utils.py @@ -1,6 +1,4 @@ import argparse -import base64 -import hashlib import logging import os import shlex @@ -12,8 +10,6 @@ import pandas as pd import requests -from osm._version import __version__ - DEFAULT_OUTPUT_DIR = "./osm_output" logger = logging.getLogger(__name__) @@ -62,35 +58,6 @@ def get_compute_context_id(): return hash(f"{os.environ.get('HOSTNAME')}_{os.environ.get('USERNAME')}") -def _upload_data(args, file_in, xml, metrics, components): - osm_api = os.environ.get("OSM_API", "http://localhost:80") - - payload = { - "osm_version": __version__, - "user_comment": args.comment, - "work": { - "user_defined_id": args.uid, - "filename": args.file.name, - "file": base64.b64encode(file_in).decode("utf-8"), - "content_hash": hashlib.sha256(file_in).hexdigest(), - }, - "client": { - "compute_context_id": get_compute_context_id(), - "email": args.email, - }, - "metrics": metrics, - "components": components, - } - # Send POST request to OSM API - response = requests.put(f"{osm_api}/upload", json=payload) - - # Check response status code - if response.status_code == 200: - print("Invocation data uploaded successfully") - else: - print(f"Failed to upload invocation data: \n {response.text}") - - def wait_for_containers(): while True: try: @@ -128,8 +95,10 @@ def _setup(args): raise FileExistsError(xml_path) elif args.filepath.name.endswith(".xml"): logger.warning( - """The input file is an xml file. Skipping the pdf to text - conversion and so ignoring requested parsers.""" + """ + The input file is an xml file. Skipping the pdf to text conversion + and so ignoring requested parsers. + """ ) args.parser = ["no-op"] metrics_path = _get_metrics_dir() / f"{args.uid}.json" @@ -141,6 +110,7 @@ def _setup(args): logger.info("Waiting for containers to be ready...") print("Waiting for containers to be ready...") wait_for_containers() + print("Containers ready!") return xml_path, metrics_path diff --git a/osm/cli.py b/osm/cli.py index 3d9a9883..e6107025 100644 --- a/osm/cli.py +++ b/osm/cli.py @@ -53,6 +53,7 @@ def parse_args(): parser.add_argument( "--comment", required=False, + default="", help="Comment to provide more information about the provided publication.", ) parser.add_argument( diff --git a/osm/pipeline/core.py b/osm/pipeline/core.py index 91c549c6..de62bd0c 100644 --- a/osm/pipeline/core.py +++ b/osm/pipeline/core.py @@ -18,27 +18,32 @@ def __init__(self, version: str = "0.0.1"): def run(self, data: Any, **kwargs) -> Any: pass - def _get_model_fields(self) -> dict[str, Any]: - return { - "name": self.name, - "version": self.version, - } + def _get_orm_fields(self) -> dict[str, Any]: + fields = {} + for fieldname in self.orm_model_class.model_fields.keys(): + if hasattr(self, fieldname): + fields[fieldname] = getattr(self, fieldname) + + return fields @property def name(self) -> str: return self.__class__.__name__ + @property + def orm_model_class(self) -> type: + return schemas.Component + @property def orm_model(self) -> schemas.Component: - if self._orm_model is None: - self._orm_model = schemas.Component( - **self._get_model_fields(), - ) + self._orm_model = self.orm_model_class( + **self._get_orm_fields(), + ) return self._orm_model - def model_dump(self) -> dict[str, Any]: + def model_dump(self, *args, **kwargs) -> dict[str, Any]: """Return a dict of the components model.""" - return self.orm_model.model_dump() + return self.orm_model.model_dump(*args, **kwargs) class Savers: diff --git a/osm/pipeline/extractors.py b/osm/pipeline/extractors.py index b36f0717..33cdea48 100644 --- a/osm/pipeline/extractors.py +++ b/osm/pipeline/extractors.py @@ -26,7 +26,6 @@ def run(self, data: str, parser: str = None) -> dict: metrics[k] = None return metrics else: - breakpoint() logger.error(f"Error: {response.text}") response.raise_for_status() diff --git a/osm/pipeline/parsers.py b/osm/pipeline/parsers.py index 75136ec4..8a96e588 100644 --- a/osm/pipeline/parsers.py +++ b/osm/pipeline/parsers.py @@ -1,17 +1,23 @@ import requests +from osm.schemas.custom_fields import LongBytes + from .core import Component SCIENCEBEAM_URL = "http://localhost:8070/api/convert" class NoopParser(Component): + """Used if the input is xml and so needs no parsing.""" + def run(self, data: bytes) -> str: + self.sample = LongBytes(data) return data.decode("utf-8") class ScienceBeamParser(Component): def run(self, data: bytes) -> str: + self.sample = LongBytes(data) headers = {"Accept": "application/tei+xml", "Content-Type": "application/pdf"} response = requests.post(SCIENCEBEAM_URL, data=data, headers=headers) if response.status_code == 200: diff --git a/osm/pipeline/savers.py b/osm/pipeline/savers.py index 3bf68ce6..943696f9 100644 --- a/osm/pipeline/savers.py +++ b/osm/pipeline/savers.py @@ -5,12 +5,13 @@ import os from pathlib import Path +import dill import requests from pydantic import ValidationError +from osm import schemas from osm._utils import get_compute_context_id from osm._version import __version__ -from osm.schemas import Invocation from .core import Component @@ -18,20 +19,46 @@ class FileSaver(Component): + """Basic saver that writes data to a file.""" + def run(self, data: str, path: Path): + """Write data to a file. + + Args: + data (str): Some data. + path (Path): A file path. + """ path.write_text(data) logger.info(f"Data saved to {path}") class JSONSaver(Component): + """Saver that writes JSON data to a file.""" + def run(self, data: dict, path: Path): + """Write output metrics to a JSON file for the user. + + Args: + data (dict): Metrics conformant to a schema. + path (Path): An output path for the metrics. + """ path.write_text(json.dumps(data)) logger.info(f"Metrics saved to {path}") print(f"Metrics saved to {path}") class OSMSaver(Component): + """A class to gather savers to run a pipeline.""" + def __init__(self, comment, email, user_defined_id, filename): + """Upload data to the OSM API. + + Args: + comment (str): A comment from the user in inform downstream analysis. + email (str): For users to be contactable for future data curation etc. + user_defined_id (str): pmid, pmcid, doi, or other unique identifier. + filename (str): Name of the file being processed. + """ super().__init__() self.compute_context_id = get_compute_context_id() self.comment = comment @@ -39,7 +66,14 @@ def __init__(self, comment, email, user_defined_id, filename): self.user_defined_id = user_defined_id self.filename = filename - def run(self, file_in: bytes, metrics: dict, components: list): + def run(self, file_in: bytes, metrics: dict, components: list[schemas.Component]): + """Save the extracted metrics to the OSM API. + + Args: + file_in (bytes): Component input. + metrics (dict): Schema conformant metrics. + components (list[schemas.Component]): parsers, extractors, and savers that constitute the pipeline. + """ osm_api = os.environ.get("OSM_API", "https://osm.pythonaisolutions.com/api") # Build the payload payload = { @@ -48,26 +82,25 @@ def run(self, file_in: bytes, metrics: dict, components: list): "work": { "user_defined_id": self.user_defined_id, "filename": self.filename, - "file": base64.b64encode(file_in).decode("utf-8"), "content_hash": hashlib.sha256(file_in).hexdigest(), }, "client": { "compute_context_id": self.compute_context_id, "email": self.email, }, - "metrics": metrics, - "components": [comp.model_dump() for comp in components], + "metrics": schemas.RtransparentMetrics(**metrics), + "components": [comp.orm_model for comp in components], } try: # Validate the payload - validated_data = Invocation(**payload) + validated_data = schemas.Invocation(**payload) # If validation passes, send POST request to OSM API. ID is not # serializable but can be excluded and created by the DB. All types - # should be serializable. If they're not then a they should be encoded - # as a string or something like that: base64.b64encode(bytes).decode("utf-8") + # should be serializable. If they're not then they should be encoded + # as a string or something like that: base64.b64encode(bytes).decode("utf-8") response = requests.put( f"{osm_api}/upload/", - json=validated_data.model_dump(exclude=["id"]), + json=validated_data.model_dump(mode="json", exclude=["id"]), ) if response.status_code == 200: print("Invocation data uploaded successfully") @@ -76,18 +109,21 @@ def run(self, file_in: bytes, metrics: dict, components: list): f"Failed to upload invocation data: \n {response.text}" ) except (ValidationError, ValueError) as e: - breakpoint() try: - payload["upload_error"] = str(e) # Quarantine the failed payload - response = requests.put(f"{osm_api}/quarantine", json=payload) + failure = schemas.Quarantine( + payload=base64.b64encode(dill.dumps(payload)).decode("utf-8"), + error_message=str(e), + ).model_dump(mode="json", exclude=["id"]) + response = requests.put(f"{osm_api}/quarantine/", json=failure) response.raise_for_status() + raise e except requests.exceptions.RequestException as qe: requests.put( - f"{osm_api}/quarantine", - json={ - "upload_error": str(e), - "recovery_error": str(qe), - }, + f"{osm_api}/quarantine/", + json=schemas.Quarantine( + error_message=str(e), + recovery_message=str(qe), + ).model_dump(mode="json", exclude=["id"]), ) - logger.warning(f"Validation failed: {e}") + raise e diff --git a/osm/schemas/__init__.py b/osm/schemas/__init__.py index 5f4eae77..c00cca22 100644 --- a/osm/schemas/__init__.py +++ b/osm/schemas/__init__.py @@ -1,5 +1,6 @@ from .schemas import Client as Client from .schemas import Component as Component from .schemas import Invocation as Invocation +from .schemas import Quarantine as Quarantine from .schemas import RtransparentMetrics as RtransparentMetrics from .schemas import Work as Work diff --git a/osm/schemas/custom_fields.py b/osm/schemas/custom_fields.py index 75c3bea0..36c381d9 100644 --- a/osm/schemas/custom_fields.py +++ b/osm/schemas/custom_fields.py @@ -86,7 +86,7 @@ class LongStr(LongField[str]): class LongBytes(LongField[bytes]): - """A bytes type that displays '...' instead of the full content in logs or tracebacks.""" + """Wrap byte streams to avoid messy output.""" _inner_schema: ClassVar[CoreSchema] = core_schema.bytes_schema() _error_kind: ClassVar[str] = "bytes_type" diff --git a/osm/schemas/metrics_schemas.py b/osm/schemas/metrics_schemas.py index f7a5a5d6..4958abea 100644 --- a/osm/schemas/metrics_schemas.py +++ b/osm/schemas/metrics_schemas.py @@ -1,7 +1,7 @@ from typing import Optional from odmantic import EmbeddedModel -from pydantic import field_validator +from pydantic import field_serializer, field_validator from osm._utils import coerce_to_string @@ -12,11 +12,6 @@ # all_indicators.csv from the rtransparent publication has both but has the following extra fields: # code_text,com_code,com_data_availibility,com_file_formats,com_general_db,com_github_data,com_specific_db,com_suppl_code,com_supplemental_data,data_text,dataset,eigenfactor_score,field,is_art,is_code_pred,is_data_pred,is_relevant_code,is_relevant_data,jif,n_cite,score,year, class RtransparentMetrics(EmbeddedModel): - model_config = { - "json_encoders": { - LongStr: lambda v: v.get_value(), - }, - } # Mandatory fields is_open_code: Optional[bool] is_open_data: Optional[bool] @@ -197,3 +192,17 @@ class RtransparentMetrics(EmbeddedModel): @field_validator("article") def fix_string(cls, v): return coerce_to_string(v) + + @field_serializer( + "data_text", + "code_text", + "coi_text", + "fund_text", + "register_text", + "funding_text", + "open_code_statements", + "open_data_category", + "open_data_statements", + ) + def serialize_longstr(self, value: Optional[LongStr]) -> Optional[str]: + return value.get_value() if value else None diff --git a/osm/schemas/schemas.py b/osm/schemas/schemas.py index cd75e83f..7f42251f 100644 --- a/osm/schemas/schemas.py +++ b/osm/schemas/schemas.py @@ -4,7 +4,7 @@ import pandas as pd from odmantic import EmbeddedModel, Field, Model -from pydantic import EmailStr, field_validator +from pydantic import EmailStr, field_serializer, field_validator from osm._utils import coerce_to_string @@ -15,9 +15,6 @@ class Component(EmbeddedModel): model_config = { "extra": "forbid", - "json_encoders": { - LongBytes: lambda v: base64.b64encode(v.get_value()).decode("utf-8"), - }, } name: str version: str @@ -28,6 +25,10 @@ class Component(EmbeddedModel): json_schema_extra={"exclude": True, "select": False, "write_only": True}, ) + @field_serializer("sample") + def serialize_longbytes(self, value: Optional[LongBytes]) -> Optional[str]: + return base64.b64encode(value.get_value()).decode("utf-8") if value else None + class Client(EmbeddedModel): model_config = {"extra": "forbid"} @@ -81,3 +82,9 @@ class Invocation(Model): created_at: datetime.datetime = Field( default_factory=lambda: datetime.datetime.now(datetime.UTC) ) + + +class Quarantine(Model): + payload: str = "" + error_message: str + recovery_message: Optional[str] = None diff --git a/pyproject.toml b/pyproject.toml index 078de165..05ddf36a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ keywords = [ ] dynamic = ["version"] dependencies = [ + "dill", "pandas", "pyarrow", "pydantic", diff --git a/web/api/main.py b/web/api/main.py index 37729516..95073d0c 100644 --- a/web/api/main.py +++ b/web/api/main.py @@ -25,7 +25,7 @@ from fastapi import FastAPI, HTTPException from odmantic import AIOEngine, ObjectId -from osm.schemas import Invocation +from osm.schemas import Invocation, Quarantine app = FastAPI() dburi = os.environ.get("MONGODB_URI", "mongodb://localhost:27017") @@ -39,6 +39,12 @@ async def upload_invocation(invocation: Invocation): return invocation +@app.put("/quarantine/", response_model=Quarantine) +async def upload_failed_invocation(quarantine: Quarantine): + await engine.save(quarantine) + return quarantine + + @app.get("/invocations/{id}", response_model=Invocation) async def get_invocation_by_id(id: ObjectId): invocation = await engine.find_one(Invocation, Invocation.id == id) diff --git a/web/dashboard/app.py b/web/dashboard/app.py index 1677fe65..cd8af205 100644 --- a/web/dashboard/app.py +++ b/web/dashboard/app.py @@ -5,6 +5,7 @@ import param import pyarrow as pa import pyarrow.dataset as ds +import pyarrow.parquet as pq import ui from main_dashboard import MainDashboard from pyarrow import compute as pc @@ -18,6 +19,7 @@ def load_data(): dset = ds.dataset(local_path, format="parquet") else: dset = ds.dataset(osh.matches_to_table(osh.get_data_from_mongo())) + pq.write_table(dset.to_table(), local_path, compression="snappy") tb = dset.to_table() split_col = pc.split_pattern( @@ -39,10 +41,8 @@ def load_data(): # necessary conversion to tuples, which is hashable type # needed for grouping - raw_data.affiliation_country = raw_data.affiliation_country.apply( - lambda cntry: tuple(cntry) - ) - raw_data.funder = raw_data.funder.apply(lambda fndrs: tuple(fndrs)) + for col in ["affiliation_country", "funder", "data_tags"]: + raw_data[col] = raw_data[col].apply(lambda x: tuple(x)) return raw_data diff --git a/web/dashboard/main_dashboard.py b/web/dashboard/main_dashboard.py index 0c49331d..b36afbd1 100644 --- a/web/dashboard/main_dashboard.py +++ b/web/dashboard/main_dashboard.py @@ -25,6 +25,7 @@ "journal", "affiliation_country", "funder", + "data_tags", ], } } @@ -62,7 +63,7 @@ class MainDashboard(param.Parameterized): """ # High-level parameters. - extraction_tool = param.Selector(default="", objects=[], label="Extraction tool") + extraction_tool = param.Selector(default="", objects=[], label="Metrics group") metrics = param.Selector(default=[], objects=[], label="Metrics") @@ -82,6 +83,11 @@ class MainDashboard(param.Parameterized): filter_funder = param.ListSelector(default=[], objects=[], label="Funder") + filter_tags = param.ListSelector(default=[], objects=[], label="Tags") + + # Internal mechanisms + trigger_rendering = param.Integer(default=0) + # UI elements echarts_pane = pn.pane.ECharts( {}, height=640, width=960, renderer="svg", options={"replaceMerge": ["series"]} @@ -120,6 +126,13 @@ def __init__(self, datasets, **params): options: self.new_picker_title("funders", select_picker, values, options), ) + self.tags_select_picker = SelectPicker.from_param( + self.param.filter_tags, + update_title_callback=lambda select_picker, + values, + options: self.new_picker_title("tags", select_picker, values, options), + ) + self.build_pubdate_filter() @pn.depends("extraction_tool", watch=True) @@ -166,7 +179,9 @@ def did_change_extraction_tool(self): self.param.filter_journal.objects = self.raw_data.journal.unique() ## affiliation country - countries_with_count = self.get_countries_with_count() + countries_with_count = self.get_col_values_with_count( + "affiliation_country", lambda x: x is None + ) def country_sorter(c): return countries_with_count[c] @@ -176,7 +191,9 @@ def country_sorter(c): ) ## funder - funders_with_count = self.get_funders_with_count() + funders_with_count = self.get_col_values_with_count( + "funder", lambda x: len(x) == 0 or len(x) == 1 and x[0] == "" + ) def funder_sorter(c): return funders_with_count[c] @@ -185,33 +202,33 @@ def funder_sorter(c): funders_with_count.keys(), key=funder_sorter, reverse=True ) + ## Tags + tags_with_count = self.get_col_values_with_count( + "data_tags", lambda x: x is None + ) + + def tags_sorter(c): + return tags_with_count[c] + + self.param.filter_tags.objects = sorted( + tags_with_count.keys(), key=tags_sorter, reverse=True + ) + # This triggers function "did_change_splitting_var" # which updates filter_journal, filter_affiliation_country and filter_funder self.splitting_var = self.param.splitting_var.objects[0] @lru_cache - def get_funders_with_count(self): - funders = {} - for row in self.raw_data.funder.values: - if len(row) == 0 or len(row) == 1 and row[0] == "": + def get_col_values_with_count(self, col, none_test): + values = {} + for row in self.raw_data[col].values: + if none_test(row): ## Keeping "None" as a string on purpose, to represent it in the SelectPicker - funders["None"] = funders.get("None", 0) + 1 + values["None"] = values.get("None", 0) + 1 else: for c in row: - funders[c] = funders.get(c, 0) + 1 - return funders - - @lru_cache - def get_countries_with_count(self): - countries = {} - for row in self.raw_data.affiliation_country.values: - if row is None: - ## Keeping "None" as a string on purpose, to represent it in the SelectPicker - countries["None"] = countries.get("None", 0) + 1 - else: - for c in row: - countries[c] = countries.get(c, 0) + 1 - return countries + values[c] = values.get(c, 0) + 1 + return values @pn.depends("splitting_var", watch=True) def did_change_splitting_var(self): @@ -235,7 +252,9 @@ def did_change_splitting_var(self): if self.splitting_var == "affiliation_country": # We want to show all countries, but pre-select only the top 10 - countries_with_count = self.get_countries_with_count() + countries_with_count = self.get_col_values_with_count( + "affiliation_country", lambda x: x is None + ) # pre-filter the countries because there are a lot countries_with_count = { @@ -264,7 +283,9 @@ def did_change_splitting_var(self): if self.splitting_var == "funder": # We want to show all funders, but pre-select only the top 10 - funders_with_count = self.get_funders_with_count() + funders_with_count = self.get_col_values_with_count( + "funder", lambda x: len(x) == 0 or len(x) == 1 and x[0] == "" + ) top_5_min = sorted( [ @@ -284,15 +305,24 @@ def did_change_splitting_var(self): else: selected_funders = self.param.filter_funder.objects + # There is currently only two tags, so no need to pre-select a top subset + selected_tags = self.param.filter_tags.objects + # Trigger a batch update of the filters value, # preventing from re-rendering the dashboard several times # and preventing intermediate states where the dashboard renders onces # with all funders for instance, and then restricting on the selected funders. + # Also, we increment the trigger_rendering to force the update of the echarts plot. + # This is usefull when switching from splitting var "None" to "data_tags" for instance. + # In this case, the selected tags don't change, and the plot won't update, hence the need + # for trigger_rendering. print("TRIGGER UPDATE") self.param.update( filter_journal=selected_journals, filter_affiliation_country=selected_countries, filter_funder=selected_funders, + filter_tags=selected_tags, + trigger_rendering=self.trigger_rendering + 1, ) if self.splitting_var == "None": @@ -342,6 +372,17 @@ def funder_filter(cell): filtered_df = filtered_df[filtered_df.funder.apply(funder_filter)] + if len(filtered_df) > 0 and len(self.filter_tags) != len( + self.param.filter_tags.objects + ): + # the filter on tags is similar to the filter on countries + def tags_filter(cell): + if cell is None: + return "None" in self.filter_tags + return any(c in self.filter_tags for c in cell) + + filtered_df = filtered_df[filtered_df.data_tags.apply(tags_filter)] + aggretations = {} for field, aggs in dims_aggregations.items(): for agg in aggs: @@ -353,6 +394,8 @@ def funder_filter(cell): result = filtered_df.groupby(groupers).agg(**aggretations).reset_index() + print("FILTERED_GROUPED_DATA_DONE", len(result)) + return result @pn.depends( @@ -361,6 +404,8 @@ def funder_filter(cell): "filter_affiliation_country", "filter_journal", "filter_funder", + "filter_tags", + "trigger_rendering", watch=True, ) def updated_echart_plot(self): @@ -407,6 +452,12 @@ def updated_echart_plot(self): splitting_var_filter = self.filter_funder splitting_var_column = "funder" splitting_var_query = lambda cell, selected_item: selected_item in cell + + elif self.splitting_var == "data_tags": + splitting_var_filter = self.filter_tags + splitting_var_column = "data_tags" + splitting_var_query = lambda cell, selected_item: selected_item in cell + else: print("Defaulting to splitting var 'journal' ") splitting_var_filter = self.filter_journal @@ -594,6 +645,7 @@ def get_sidebar(self): self.journal_select_picker, self.affiliation_country_select_picker, self.funder_select_picker, + self.tags_select_picker, ] sidebar = pn.Column(*items)