From eddd42679f9209ec0fb79a04b33826af93e19c35 Mon Sep 17 00:00:00 2001
From: leej3
Date: Thu, 29 Aug 2024 10:57:00 +0100
Subject: [PATCH] Sanitize UIDs for file paths and make schema names explicit

---
 osm/_utils.py                 | 28 +++++++++++++++++++++++++++-
 osm/cli.py                    |  5 ++++-
 osm/pipeline/core.py          |  2 +-
 osm/pipeline/extractors.py    |  3 ++-
 osm/pipeline/parsers.py       |  5 ++++-
 osm/pipeline/savers.py        |  3 ++-
 osm/schemas/__init__.py       |  1 +
 osm/schemas/schema_helpers.py | 12 +++++-------
 scripts/merge_funder.py       |  4 +++-
 9 files changed, 49 insertions(+), 14 deletions(-)

diff --git a/osm/_utils.py b/osm/_utils.py
index 7dcacab1..55603468 100644
--- a/osm/_utils.py
+++ b/osm/_utils.py
@@ -100,6 +100,32 @@ def compose_down():
     print(f"Logs of docker containers are saved at {docker_log}")
 
 
+def make_uid_path_safe(uid: str) -> str:
+    """
+    Sanitizes a given string to make it safe for use as a file path.
+
+    Args:
+    - uid (str): The original string that needs to be sanitized.
+
+    Returns:
+    - str: A sanitized string safe for use as a file path.
+    """
+    # Define a regex pattern to match unsafe characters for file paths
+    unsafe_characters_pattern = r'[\/\\:*?"<>|]'
+
+    # Replace unsafe characters with an underscore
+    safe_uid = re.sub(unsafe_characters_pattern, "_", uid)
+
+    # Remove leading/trailing whitespace
+    safe_uid = safe_uid.strip()
+
+    # Replace multiple consecutive spaces or underscores with a single underscore
+    safe_uid = re.sub(r"[\s_]+", "_", safe_uid)
+
+    # Return the sanitized UID
+    return safe_uid
+
+
 def _setup(args):
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -115,7 +141,7 @@ def _setup(args):
             """
         )
         args.parser = ["no-op"]
-    metrics_path = _get_metrics_dir() / f"{args.uid}.json"
+    metrics_path = _get_metrics_dir() / f"{make_uid_path_safe(args.uid)}.json"
     if metrics_path.exists():
         raise FileExistsError(metrics_path)
     # create logs directory if necessary
diff --git a/osm/cli.py b/osm/cli.py
index bc193d06..be8f55bf 100644
--- a/osm/cli.py
+++ b/osm/cli.py
@@ -99,7 +99,10 @@ def main():
                 ),
             ),
         )
-        pipeline.run(user_managed_compose=args.user_managed_compose)
+        pipeline.run(
+            user_managed_compose=args.user_managed_compose,
+            llm_model=args.llm_model,
+        )
     finally:
         if not args.user_managed_compose:
             compose_down()
diff --git a/osm/pipeline/core.py b/osm/pipeline/core.py
index 952efb15..eb2daf25 100644
--- a/osm/pipeline/core.py
+++ b/osm/pipeline/core.py
@@ -113,7 +113,7 @@ def run(self, user_managed_compose: bool = False, llm_model: str = None):
         self.savers.save_osm(
             data=self.file_data,
             metrics=extracted_metrics,
-            components=[*self.parsers, *self.extractors, *self.savers],
+            components=[parser, extractor, *self.savers],
         )
 
         self.savers.save_json(extracted_metrics, self.metrics_path)
diff --git a/osm/pipeline/extractors.py b/osm/pipeline/extractors.py
index 686f4b49..218f501f 100644
--- a/osm/pipeline/extractors.py
+++ b/osm/pipeline/extractors.py
@@ -11,7 +11,8 @@
 
 
 class RTransparentExtractor(Component):
-    def _run(self, data: bytes, parser: str = None, **kwargs) -> dict:
+    def _run(self, data: bytes, **kwargs) -> dict:
+        parser = kwargs["parser"]
         self.sample = LongBytes(data)
 
         # Prepare the file to be sent as a part of form data
diff --git a/osm/pipeline/parsers.py b/osm/pipeline/parsers.py
index 8dc59d60..91eb92dc 100644
--- a/osm/pipeline/parsers.py
+++ b/osm/pipeline/parsers.py
@@ -30,7 +30,10 @@ class PMCParser(NoopParser):
 
 
 class ScienceBeamParser(Component):
-    def _run(self, data: bytes, user_managed_compose=False, **kwargs) -> str:
+    def _run(self, data: bytes, user_managed_compose: bool = False, **kwargs) -> str:
+        user_managed_compose = user_managed_compose or kwargs.get(
+            "user_managed_compose", False
+        )
         self.sample = LongBytes(data)
         headers = {"Accept": "application/tei+xml", "Content-Type": "application/pdf"}
         files = {"file": ("input.pdf", io.BytesIO(data), "application/pdf")}
diff --git a/osm/pipeline/savers.py b/osm/pipeline/savers.py
index 50a320c4..9ab9c490 100644
--- a/osm/pipeline/savers.py
+++ b/osm/pipeline/savers.py
@@ -14,6 +14,7 @@
 from osm import schemas
 from osm._utils import get_compute_context_id
 from osm._version import __version__
+from osm.schemas.schema_helpers import get_metrics_schemas
 
 from .core import Component
 
@@ -83,6 +84,7 @@ def _run(self, data: bytes, metrics: dict, components: list[schemas.Component]):
         osm_api = os.environ.get("OSM_API", "https://opensciencemetrics.org/api")
         print(f"Using OSM API: {osm_api}")
         # Build the payload
+        metrics_schemas = get_metrics_schemas()
         try:
             payload = {
                 "osm_version": __version__,
@@ -109,7 +111,6 @@ def _run(self, data: bytes, metrics: dict, components: list[schemas.Component]):
             raise e
         try:
             # Validate the payload
-            breakpoint()
             validated_data = schemas.Invocation(**payload)
             # If validation passes, send POST request to OSM API. ID is not
             # serializable but can be excluded and created by the DB. All types
diff --git a/osm/schemas/__init__.py b/osm/schemas/__init__.py
index 4672e320..b9f7da7b 100644
--- a/osm/schemas/__init__.py
+++ b/osm/schemas/__init__.py
@@ -1,6 +1,7 @@
 from .schemas import Client as Client
 from .schemas import Component as Component
 from .schemas import Invocation as Invocation
+from .schemas import LLMExtractorMetrics as LLMExtractorMetrics
 from .schemas import ManualAnnotationNIMHDSST as ManualAnnotationNIMHDSST
 from .schemas import PayloadError as PayloadError
 from .schemas import Quarantine as Quarantine
diff --git a/osm/schemas/schema_helpers.py b/osm/schemas/schema_helpers.py
index 1d940050..5ef532ae 100644
--- a/osm/schemas/schema_helpers.py
+++ b/osm/schemas/schema_helpers.py
@@ -66,15 +66,13 @@ def odmantic_to_pyarrow(schema_dict):
     return pa.schema(fields)
 
 
-def get_pyarrow_schema(schema_name="RtransparentMetrics"):
+def get_pyarrow_schema(schema_name):
     odmantic_schema_json = getattr(schemas, schema_name).model_json_schema()
     pyarrow_schema = odmantic_to_pyarrow(odmantic_schema_json)
     return pyarrow_schema
 
 
-def get_table_with_schema(
-    df, other_fields=None, raise_error=True, schema_name="RtransparentMetrics"
-):
+def get_table_with_schema(df, *, schema_name, other_fields=None, raise_error=True):
     other_fields = other_fields or []
     pyarrow_schema = get_pyarrow_schema(schema_name)
     adjusted_schema = adjust_schema_to_dataframe(
@@ -181,14 +179,14 @@ def get_all_data_from_mongo(aggregation: list[dict] | None = None) -> dict[pa.Table]:
     mschemas = get_metrics_schemas()
     tables = {}
     for sname in mschemas.keys():
-        matches = get_data_from_mongo_for_schema(schema_name=sname)
+        matches = get_data_from_mongo_for_schema(sname)
         table = matches_to_table(matches, schema_name=sname)
         tables[sname] = table
     return tables
 
 
 def get_data_from_mongo_for_schema(
-    aggregation: list[dict] | None = None, schema_name="RtransparentMetrics"
+    schema_name, aggregation: list[dict] | None = None
 ) -> Iterator[dict]:
     if aggregation is None:
         aggregation = [
@@ -252,7 +250,7 @@ def infer_type_for_column(column):
 
 
 def matches_to_table(
-    matches: Iterator[dict], batch_size: int = 1000, schema_name="RtransparentMetrics"
+    matches: Iterator[dict], *, schema_name, batch_size: int = 1000
 ) -> pa.Table:
     # Initialize an empty list to store batches of tables
     tables = []
diff --git a/scripts/merge_funder.py b/scripts/merge_funder.py
index 4d37c36b..1bf01138 100644
--- a/scripts/merge_funder.py
+++ b/scripts/merge_funder.py
@@ -87,7 +87,9 @@ def main():
     print("Converting to pyarrow...")
     funder_field = pa.field("funder", pa.list_(pa.string()), nullable=True)
     tb = get_table_with_schema(
-        dataset.assign(funder=None), [funder_field], schema_name=args.schema_name
+        dataset.assign(funder=None),
+        schema_name=args.schema_name,
+        other_fields=[funder_field],
     )
 
     print("Merging with funders...")
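
Note: a minimal sketch of the new sanitizer's behavior, assuming the patched
osm/_utils.py is importable; the input strings are illustrative only.

    from osm._utils import make_uid_path_safe

    # Unsafe path characters (/ \ : * ? " < > |) become underscores, then
    # runs of whitespace/underscores collapse to a single underscore.
    make_uid_path_safe("10.1001/jama.2020.123")  # -> "10.1001_jama.2020.123"
    make_uid_path_safe("  a b//c  ")             # -> "a_b_c"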
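
Note: with schema_name now required and keyword-only throughout
schema_helpers, call sites must name the schema explicitly. A sketch of the
updated calling convention ("RtransparentMetrics" is the default this patch
removes; df and funder_field stand in for real inputs):

    tb = get_table_with_schema(
        df,
        schema_name="RtransparentMetrics",
        other_fields=[funder_field],
    )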