diff --git a/osm/schemas/__init__.py b/osm/schemas/__init__.py
index eeb4eaa9..4672e320 100644
--- a/osm/schemas/__init__.py
+++ b/osm/schemas/__init__.py
@@ -1,6 +1,7 @@
 from .schemas import Client as Client
 from .schemas import Component as Component
 from .schemas import Invocation as Invocation
+from .schemas import ManualAnnotationNIMHDSST as ManualAnnotationNIMHDSST
 from .schemas import PayloadError as PayloadError
 from .schemas import Quarantine as Quarantine
 from .schemas import RtransparentMetrics as RtransparentMetrics
diff --git a/osm/schemas/metrics_schemas.py b/osm/schemas/metrics_schemas.py
index 4958abea..01004313 100644
--- a/osm/schemas/metrics_schemas.py
+++ b/osm/schemas/metrics_schemas.py
@@ -206,3 +206,15 @@ def fix_string(cls, v):
     )
     def serialize_longstr(self, value: Optional[LongStr]) -> Optional[str]:
         return value.get_value() if value else None
+
+
+class ManualAnnotationNIMHDSST(EmbeddedModel):
+    pmid: Optional[int]
+    DOI: Optional[str]
+    Alternative_link: Optional[str]
+    rtransparent_is_open_data: Optional[bool]
+    rtransparent_open_data_statements: Optional[str]
+    manual_is_open_data: Optional[bool]
+    manual_data_statements: Optional[str]
+    Notes: Optional[LongStr]
+    PMID_raw: Optional[int]
diff --git a/osm/schemas/schema_helpers.py b/osm/schemas/schema_helpers.py
index 5ea35f05..1223cea3 100644
--- a/osm/schemas/schema_helpers.py
+++ b/osm/schemas/schema_helpers.py
@@ -11,25 +11,17 @@
 
 from osm import __version__, schemas
 from osm._utils import flatten_dict
-from osm.schemas import Client, Invocation, RtransparentMetrics, Work
+from osm.schemas import Client, Invocation, Work
 
 logger = logging.getLogger(__name__)
 
 
-def irp_data_processing(row):
-    return row
-
-
 def rtransparent_pub_data_processing(row):
     row["is_open_code"] = row.pop("is_code_pred")
     row["is_open_data"] = row.pop("is_data_pred")
     return row
 
 
-def theneuro_data_processing(row):
-    return row
-
-
 def types_mapper(pa_type):
     if pa.types.is_int64(pa_type):
         # Map pyarrow int64 to pandas Int64 (nullable integer)
@@ -74,15 +66,17 @@
     return pa.schema(fields)
 
 
-def get_pyarrow_schema(metrics_type="RtransparentMetrics"):
-    odmantic_schema_json = getattr(schemas, metrics_type).model_json_schema()
+def get_pyarrow_schema(metrics_schema="RtransparentMetrics"):
+    odmantic_schema_json = getattr(schemas, metrics_schema).model_json_schema()
     pyarrow_schema = odmantic_to_pyarrow(odmantic_schema_json)
     return pyarrow_schema
 
 
-def get_table_with_schema(df, other_fields=None, raise_error=True):
+def get_table_with_schema(
+    df, other_fields=None, raise_error=True, metrics_schema="RtransparentMetrics"
+):
     other_fields = other_fields or []
-    pyarrow_schema = get_pyarrow_schema()
+    pyarrow_schema = get_pyarrow_schema(metrics_schema)
     adjusted_schema = adjust_schema_to_dataframe(
         pyarrow_schema, df, other_fields=other_fields
     )
@@ -114,7 +108,9 @@
     return pa.schema(fields)
 
 
-def transform_data(table: pa.Table, raise_error=True, **kwargs) -> Iterator[dict]:
+def transform_data(
+    table: pa.Table, *, metrics_schema, raise_error=True, **kwargs
+) -> Iterator[dict]:
     """
     Process each row in a PyArrow Table in a memory-efficient way and yield JSON payloads.
""" @@ -124,7 +120,7 @@ def transform_data(table: pa.Table, raise_error=True, **kwargs) -> Iterator[dict for row in batch.to_pylist(): try: # Process each row using the existing get_invocation logic - invocation = get_invocation(row, **kwargs) + invocation = get_invocation(row, metrics_schema, **kwargs) yield invocation.model_dump(mode="json", exclude="id") except (KeyError, ValidationError) as e: if raise_error: @@ -134,11 +130,11 @@ def transform_data(table: pa.Table, raise_error=True, **kwargs) -> Iterator[dict continue -def get_invocation(row, **kwargs): +def get_invocation(row, metrics_schema, **kwargs): kwargs["data_tags"] = kwargs.get("data_tags") or [] if kwargs.get("custom_processing") is not None: # hack to run custom processing functions from this module - func = globals()[kwargs["custom_processing"]] + func = globals().get(kwargs["custom_processing"], lambda x: x) row = func(row) work = Work( user_defined_id=row.get("pmid"), @@ -153,9 +149,8 @@ def get_invocation(row, **kwargs): compute_context_id=kwargs.get("compute_context_id", 999), email=kwargs.get("email"), ) - metrics = RtransparentMetrics(**row) invocation = Invocation( - metrics=metrics, + metrics=metrics_schema(**row), osm_version=kwargs.get("osm_version", __version__), client=client, work=work, @@ -175,22 +170,23 @@ def get_data_from_mongo(aggregation: list[dict] | None = None) -> Iterator[dict] if aggregation is None: aggregation = [ { - "$match": { - "data_tags": "bulk_upload", - # "work.pmid": {"$regex":r"^2"}, - # "metrics.year": {"$gt": 2000}, - # "metrics.is_data_pred": {"$eq": True}, - }, + "$match": {}, }, { "$project": { # "osm_version": True, "funder": True, "data_tags": True, + "doi": True, + "metrics_group": True, "work.pmid": True, "metrics.year": True, "metrics.is_open_data": True, "metrics.is_open_code": True, + "metrics.manual_is_open_code": True, + "metrics.rtransparent_is_open_code": True, + "metrics.manual_is_open_data": True, + "metrics.rtransparent_is_open_data": True, "metrics.affiliation_country": True, "metrics.journal": True, "created_at": True, @@ -248,6 +244,10 @@ def matches_to_table(matches: Iterator[dict], batch_size: int = 1000) -> pa.Tabl # Convert batch of dicts to DataFrame df = pd.DataFrame(batch) + if "created_at" in df.columns and df["created_at"].dtype == "O": + df["created_at"] = pd.to_datetime( + df["created_at"], utc=True, format="ISO8601" + ) # Drop the `_id` column if it exists if "_id" in df.columns: @@ -262,7 +262,18 @@ def matches_to_table(matches: Iterator[dict], batch_size: int = 1000) -> pa.Tabl # Extend schema to include any additional columns in the DataFrame extra_columns = [col for col in df.columns if col not in adjusted_schema.names] for col in extra_columns: - inferred_type = infer_type_for_column(df[col]) + if col == "funder": + inferred_type = pa.list_(pa.string()) + elif col == "data_tags": + inferred_type = pa.list_(pa.string()) + elif col == "affiliation_country": + inferred_type = pa.list_(pa.string()) + elif col == "rtransparent_is_open_data": + inferred_type = pa.bool + elif col == "manual_is_open_data": + inferred_type = pa.bool_() + else: + inferred_type = infer_type_for_column(df[col]) adjusted_schema = adjusted_schema.append(pa.field(col, inferred_type)) # Convert DataFrame to PyArrow Table with the extended schema diff --git a/osm/schemas/schemas.py b/osm/schemas/schemas.py index 43efbbd1..1406e218 100644 --- a/osm/schemas/schemas.py +++ b/osm/schemas/schemas.py @@ -1,15 +1,15 @@ import base64 import datetime -from typing import Optional +from typing 
 
 import pandas as pd
 from odmantic import EmbeddedModel, Field, Model
-from pydantic import EmailStr, field_serializer, field_validator
+from pydantic import EmailStr, field_serializer, field_validator, model_validator
 
 from osm._utils import coerce_to_string
 
 from .custom_fields import LongBytes
-from .metrics_schemas import RtransparentMetrics
+from .metrics_schemas import ManualAnnotationNIMHDSST, RtransparentMetrics
 
 
 class Component(EmbeddedModel):
@@ -71,7 +71,8 @@
     """
 
     model_config = {"extra": "forbid"}
-    metrics: RtransparentMetrics
+    metrics: Union[RtransparentMetrics, ManualAnnotationNIMHDSST]
+    metrics_group: str
     components: Optional[list[Component]] = []
     work: Work
     client: Client
@@ -80,20 +81,35 @@
     funder: Optional[list[str]] = []
     data_tags: list[str] = []
     created_at: datetime.datetime = Field(
-        default_factory=lambda: datetime.datetime.now(datetime.UTC)
+        default_factory=lambda: datetime.datetime.now(datetime.UTC).replace(
+            microsecond=0
+        )
     )
 
+    @model_validator(mode="before")
+    def set_metrics_group(cls, values):
+        metrics = values.get("metrics")
+        if isinstance(metrics, (RtransparentMetrics, ManualAnnotationNIMHDSST)):
+            values["metrics_group"] = metrics.__class__.__name__
+        else:
+            raise ValueError("Unknown metrics type")
+        return values
+
 
 class Quarantine(Model):
     payload: bytes = b""
     error_message: str
     created_at: datetime.datetime = Field(
-        default_factory=lambda: datetime.datetime.now(datetime.UTC)
+        default_factory=lambda: datetime.datetime.now(datetime.UTC).replace(
+            microsecond=0
+        )
     )
 
 
class PayloadError(Model):
     error_message: str
     created_at: datetime.datetime = Field(
-        default_factory=lambda: datetime.datetime.now(datetime.UTC)
+        default_factory=lambda: datetime.datetime.now(datetime.UTC).replace(
+            microsecond=0
+        )
     )
diff --git a/scripts/invocation_upload.py b/scripts/invocation_upload.py
index 5935314e..6da17f8d 100644
--- a/scripts/invocation_upload.py
+++ b/scripts/invocation_upload.py
@@ -6,9 +6,12 @@
 import pyarrow.dataset as ds
 import pymongo
 
+from osm import schemas
 from osm.schemas import Component, schema_helpers
 from osm.schemas.schema_helpers import transform_data
 
+logger = logging.getLogger(__name__)
+
 DB_NAME = os.environ["DB_NAME"]
 MONGODB_URI = os.environ["MONGODB_URI"]
 
@@ -24,10 +27,16 @@
     "components": [Component(name="Sciencebeam parser/RTransparent", version="x.x.x")],
 }
 theneuro_kwargs = {
-    "data_tags": ["Th Neuro"],
+    "data_tags": ["The Neuro"],
     "user_comment": "Bulk upload of The Neuro data containing OddPub metrics underlying RTransparent metrics for open code/data.",
     "components": [Component(name="TheNeuroOddPub", version="x.x.x")],
 }
+manual_scoring_kwargs = {
+    "data_tags": ["Manual Annotation NIMH/DSST"],
+    "user_comment": "Bulk upload of some manually scored open code/data along with RTransparent extracted equivalents.",
+    "components": [Component(name="ManualAnnotation-NIMHDSST", version="x.x.x")],
+    "metrics_schema": schemas.ManualAnnotationNIMHDSST,
+}
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 )
@@ -69,15 +78,18 @@ def get_data(args):
 def get_upload_kwargs(args):
     if args.custom_processing:
-        assert hasattr(
-            schema_helpers, args.custom_processing
-        ), f"Custom processing function {args.custom_processing} not found"
+        if not hasattr(schema_helpers, args.custom_processing):
+            logger.warning(
+                f"Custom processing function {args.custom_processing} not found"
+            )
"rtransparent_pub_data_processing": kwargs = rtrans_publication_kwargs elif args.custom_processing == "irp_data_processing": kwargs = irp_kwargs elif args.custom_processing == "theneuro_data_processing": kwargs = theneuro_kwargs + elif args.custom_processing == "manual_scoring_data_processing": + kwargs = manual_scoring_kwargs else: raise ValueError( f"Kwargs associated with {args.custom_processing} not found" @@ -95,10 +107,13 @@ def main(): args = parse_args() tb = get_data(args) upload_kwargs = get_upload_kwargs(args) + schema = upload_kwargs.pop("metrics_schema", schemas.RtransparentMetrics) try: db = pymongo.MongoClient(MONGODB_URI).osm - db.invocation.insert_many(transform_data(tb, **upload_kwargs), ordered=False) + db.invocation.insert_many( + transform_data(tb, metrics_schema=schema, **upload_kwargs), ordered=False + ) except Exception as e: logger.error(f"Failed to process data: {e}") raise e diff --git a/scripts/merge_funder.py b/scripts/merge_funder.py index ff42d1c6..4c2b3e91 100644 --- a/scripts/merge_funder.py +++ b/scripts/merge_funder.py @@ -48,10 +48,15 @@ def get_user_args(): parser.add_argument("dataset_path", help="Path to the dataset file") parser.add_argument("merge_col", default="pmcid_pmc", help="Column to merge on") parser.add_argument( - "--funder_path", + "--funder-path", help="Path to the funders file", default="tempdata/funders.feather", ) + parser.add_argument( + "--metrics-schema", + help="Name of the schema class to use in order to validate the data", + default="RTransparentMetrics", + ) return parser.parse_args() @@ -81,7 +86,9 @@ def main(): print("Converting to pyarrow...") funder_field = pa.field("funder", pa.list_(pa.string()), nullable=True) - tb = get_table_with_schema(dataset.assign(funder=None), [funder_field]) + tb = get_table_with_schema( + dataset.assign(funder=None), [funder_field], metrics_schema=args.metrics_schema + ) print("Merging with funders...") merge_funder(funder, tb, args.merge_col)