Commit a12873a

extend upload scripts for new data
leej3 committed Aug 27, 2024
1 parent 4c37857 commit a12873a
Showing 6 changed files with 102 additions and 40 deletions.
1 change: 1 addition & 0 deletions osm/schemas/__init__.py
@@ -1,6 +1,7 @@
from .schemas import Client as Client
from .schemas import Component as Component
from .schemas import Invocation as Invocation
from .schemas import ManualAnnotationNIMHDSST as ManualAnnotationNIMHDSST
from .schemas import PayloadError as PayloadError
from .schemas import Quarantine as Quarantine
from .schemas import RtransparentMetrics as RtransparentMetrics
12 changes: 12 additions & 0 deletions osm/schemas/metrics_schemas.py
@@ -206,3 +206,15 @@ def fix_string(cls, v):
)
def serialize_longstr(self, value: Optional[LongStr]) -> Optional[str]:
return value.get_value() if value else None


class ManualAnnotationNIMHDSST(EmbeddedModel):
pmid: Optional[int]
DOI: Optional[str]
Alternative_link: Optional[str]
rtransparent_is_open_data: Optional[bool]
rtransparent_open_data_statements: Optional[str]
manual_is_open_data: Optional[bool]
manual_data_statements: Optional[str]
Notes: Optional[LongStr]
PMID_raw: Optional[int]
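
For orientation, a minimal usage sketch of the new embedded model (illustrative values only, not part of the diff; it assumes the package is importable and passes every field explicitly to sidestep questions about which optionals carry defaults):

# Illustrative sketch: validating one manually annotated record with the new
# embedded model. All field values below are invented.
from osm.schemas import ManualAnnotationNIMHDSST

record = ManualAnnotationNIMHDSST(
    pmid=12345678,
    DOI="10.1000/example.doi",
    Alternative_link=None,
    rtransparent_is_open_data=True,
    rtransparent_open_data_statements="Data are available at ...",
    manual_is_open_data=True,
    manual_data_statements="Data deposited in a public repository.",
    Notes=None,
    PMID_raw=12345678,
)
print(record.pmid, record.manual_is_open_data)
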
63 changes: 37 additions & 26 deletions osm/schemas/schema_helpers.py
@@ -11,25 +11,17 @@

from osm import __version__, schemas
from osm._utils import flatten_dict
from osm.schemas import Client, Invocation, RtransparentMetrics, Work
from osm.schemas import Client, Invocation, Work

logger = logging.getLogger(__name__)


def irp_data_processing(row):
return row


def rtransparent_pub_data_processing(row):
row["is_open_code"] = row.pop("is_code_pred")
row["is_open_data"] = row.pop("is_data_pred")
return row


def theneuro_data_processing(row):
return row


def types_mapper(pa_type):
if pa.types.is_int64(pa_type):
# Map pyarrow int64 to pandas Int64 (nullable integer)
@@ -74,15 +66,17 @@ def odmantic_to_pyarrow(schema_dict):
return pa.schema(fields)


def get_pyarrow_schema(metrics_type="RtransparentMetrics"):
odmantic_schema_json = getattr(schemas, metrics_type).model_json_schema()
def get_pyarrow_schema(metrics_schema="RtransparentMetrics"):
odmantic_schema_json = getattr(schemas, metrics_schema).model_json_schema()
pyarrow_schema = odmantic_to_pyarrow(odmantic_schema_json)
return pyarrow_schema


def get_table_with_schema(df, other_fields=None, raise_error=True):
def get_table_with_schema(
df, other_fields=None, raise_error=True, metrics_schema="RtransparentMetrics"
):
other_fields = other_fields or []
pyarrow_schema = get_pyarrow_schema()
pyarrow_schema = get_pyarrow_schema(metrics_schema)
adjusted_schema = adjust_schema_to_dataframe(
pyarrow_schema, df, other_fields=other_fields
)
@@ -114,7 +108,9 @@ def adjust_schema_to_dataframe(schema, df, other_fields: list = None):
return pa.schema(fields)


def transform_data(table: pa.Table, raise_error=True, **kwargs) -> Iterator[dict]:
def transform_data(
table: pa.Table, *, metrics_schema, raise_error=True, **kwargs
) -> Iterator[dict]:
"""
Process each row in a PyArrow Table in a memory-efficient way and yield JSON payloads.
"""
@@ -124,7 +120,7 @@ def transform_data(table: pa.Table, raise_error=True, **kwargs) -> Iterator[dict
for row in batch.to_pylist():
try:
# Process each row using the existing get_invocation logic
invocation = get_invocation(row, **kwargs)
invocation = get_invocation(row, metrics_schema, **kwargs)
yield invocation.model_dump(mode="json", exclude="id")
except (KeyError, ValidationError) as e:
if raise_error:
@@ -134,11 +130,11 @@ def transform_data(table: pa.Table, raise_error=True, **kwargs) -> Iterator[dict
continue


def get_invocation(row, **kwargs):
def get_invocation(row, metrics_schema, **kwargs):
kwargs["data_tags"] = kwargs.get("data_tags") or []
if kwargs.get("custom_processing") is not None:
# hack to run custom processing functions from this module
func = globals()[kwargs["custom_processing"]]
func = globals().get(kwargs["custom_processing"], lambda x: x)
row = func(row)
work = Work(
user_defined_id=row.get("pmid"),
@@ -153,9 +149,8 @@ def get_invocation(row, **kwargs):
compute_context_id=kwargs.get("compute_context_id", 999),
email=kwargs.get("email"),
)
metrics = RtransparentMetrics(**row)
invocation = Invocation(
metrics=metrics,
metrics=metrics_schema(**row),
osm_version=kwargs.get("osm_version", __version__),
client=client,
work=work,
@@ -175,22 +170,23 @@ def get_data_from_mongo(aggregation: list[dict] | None = None) -> Iterator[dict]
if aggregation is None:
aggregation = [
{
"$match": {
"data_tags": "bulk_upload",
# "work.pmid": {"$regex":r"^2"},
# "metrics.year": {"$gt": 2000},
# "metrics.is_data_pred": {"$eq": True},
},
"$match": {},
},
{
"$project": {
# "osm_version": True,
"funder": True,
"data_tags": True,
"doi": True,
"metrics_group": True,
"work.pmid": True,
"metrics.year": True,
"metrics.is_open_data": True,
"metrics.is_open_code": True,
"metrics.manual_is_open_code": True,
"metrics.rtransparent_is_open_code": True,
"metrics.manual_is_open_data": True,
"metrics.rtransparent_is_open_data": True,
"metrics.affiliation_country": True,
"metrics.journal": True,
"created_at": True,
@@ -248,6 +244,10 @@ def matches_to_table(matches: Iterator[dict], batch_size: int = 1000) -> pa.Tabl

# Convert batch of dicts to DataFrame
df = pd.DataFrame(batch)
if "created_at" in df.columns and df["created_at"].dtype == "O":
df["created_at"] = pd.to_datetime(
df["created_at"], utc=True, format="ISO8601"
)

# Drop the `_id` column if it exists
if "_id" in df.columns:
@@ -262,7 +262,18 @@ def matches_to_table(matches: Iterator[dict], batch_size: int = 1000) -> pa.Tabl
# Extend schema to include any additional columns in the DataFrame
extra_columns = [col for col in df.columns if col not in adjusted_schema.names]
for col in extra_columns:
inferred_type = infer_type_for_column(df[col])
if col == "funder":
inferred_type = pa.list_(pa.string())
elif col == "data_tags":
inferred_type = pa.list_(pa.string())
elif col == "affiliation_country":
inferred_type = pa.list_(pa.string())
elif col == "rtransparent_is_open_data":
inferred_type = pa.bool
elif col == "manual_is_open_data":
inferred_type = pa.bool_()
else:
inferred_type = infer_type_for_column(df[col])
adjusted_schema = adjusted_schema.append(pa.field(col, inferred_type))

# Convert DataFrame to PyArrow Table with the extended schema
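
A hedged sketch of how the reworked helpers fit together (everything below is hypothetical: the DataFrame contents are invented, and it assumes the pyarrow conversion and missing optional columns are tolerated for the new model as they are for RtransparentMetrics):

# Hypothetical sketch of the new metrics_schema plumbing in schema_helpers.
import pandas as pd

from osm import schemas
from osm.schemas.schema_helpers import (
    get_pyarrow_schema,
    get_table_with_schema,
    transform_data,
)

# Resolve a pyarrow schema from the class name, as get_pyarrow_schema does internally.
arrow_schema = get_pyarrow_schema("ManualAnnotationNIMHDSST")
print(arrow_schema.names[:3])

# Toy rows whose columns are a subset of the model's fields.
df = pd.DataFrame([{"pmid": 12345678, "manual_is_open_data": True}])
tb = get_table_with_schema(df, metrics_schema="ManualAnnotationNIMHDSST")

# transform_data now takes the schema class as a keyword-only argument and
# forwards it to get_invocation, which instantiates metrics_schema(**row).
payload_iter = transform_data(
    tb, metrics_schema=schemas.ManualAnnotationNIMHDSST, raise_error=False
)
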
30 changes: 23 additions & 7 deletions osm/schemas/schemas.py
@@ -1,15 +1,15 @@
import base64
import datetime
from typing import Optional
from typing import Optional, Union

import pandas as pd
from odmantic import EmbeddedModel, Field, Model
from pydantic import EmailStr, field_serializer, field_validator
from pydantic import EmailStr, field_serializer, field_validator, model_validator

from osm._utils import coerce_to_string

from .custom_fields import LongBytes
from .metrics_schemas import RtransparentMetrics
from .metrics_schemas import ManualAnnotationNIMHDSST, RtransparentMetrics


class Component(EmbeddedModel):
@@ -71,7 +71,8 @@ class Invocation(Model):
"""

model_config = {"extra": "forbid"}
metrics: RtransparentMetrics
    metrics: Union[RtransparentMetrics, ManualAnnotationNIMHDSST]
metrics_group: str
components: Optional[list[Component]] = []
work: Work
client: Client
@@ -80,20 +81,35 @@ class Invocation(Model):
funder: Optional[list[str]] = []
data_tags: list[str] = []
created_at: datetime.datetime = Field(
default_factory=lambda: datetime.datetime.now(datetime.UTC)
default_factory=lambda: datetime.datetime.now(datetime.UTC).replace(
microsecond=0
)
)

@model_validator(mode="before")
def set_metrics_group(cls, values):
metrics = values.get("metrics")
if isinstance(metrics, (RtransparentMetrics, ManualAnnotationNIMHDSST)):
values["metrics_group"] = metrics.__class__.__name__
else:
raise ValueError("Unknown metrics type")
return values


class Quarantine(Model):
payload: bytes = b""
error_message: str
created_at: datetime.datetime = Field(
default_factory=lambda: datetime.datetime.now(datetime.UTC)
default_factory=lambda: datetime.datetime.now(datetime.UTC).replace(
microsecond=0
)
)


class PayloadError(Model):
error_message: str
created_at: datetime.datetime = Field(
default_factory=lambda: datetime.datetime.now(datetime.UTC)
default_factory=lambda: datetime.datetime.now(datetime.UTC).replace(
microsecond=0
)
)
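
The effect of the new "before" validator, restated as a standalone sketch (the function below is illustrative and not part of the schema; it only mirrors the logic that stamps metrics_group from the metrics payload's class name):

# Illustrative restatement of set_metrics_group's logic (not the validator itself).
from osm.schemas import ManualAnnotationNIMHDSST, RtransparentMetrics


def derive_metrics_group(metrics) -> str:
    # The stored label is simply the class name of the validated metrics object.
    if isinstance(metrics, (RtransparentMetrics, ManualAnnotationNIMHDSST)):
        return metrics.__class__.__name__
    raise ValueError("Unknown metrics type")


# e.g. derive_metrics_group(some_manual_annotation) -> "ManualAnnotationNIMHDSST"
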
25 changes: 20 additions & 5 deletions scripts/invocation_upload.py
@@ -6,9 +6,12 @@
import pyarrow.dataset as ds
import pymongo

from osm import schemas
from osm.schemas import Component, schema_helpers
from osm.schemas.schema_helpers import transform_data

logger = logging.getLogger(__name__)

DB_NAME = os.environ["DB_NAME"]
MONGODB_URI = os.environ["MONGODB_URI"]

@@ -24,10 +27,16 @@
"components": [Component(name="Sciencebeam parser/RTransparent", version="x.x.x")],
}
theneuro_kwargs = {
"data_tags": ["Th Neuro"],
"data_tags": ["The Neuro"],
"user_comment": "Bulk upload of The Neuro data containing OddPub metrics underlying RTransparent metrics for open code/data.",
"components": [Component(name="TheNeuroOddPub", version="x.x.x")],
}
manual_scoring_kwargs = {
"data_tags": ["Manual Annotation NIMH/DSST"],
"user_comment": "Bulk upload of some manually scored open code/data along with RTransparent extracted equivalents.",
"components": [Component(name="ManualAnnotation-NIMHDSST", version="x.x.x")],
"metrics_schema": schemas.ManualAnnotationNIMHDSST,
}
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
@@ -69,15 +78,18 @@ def get_data(args):

def get_upload_kwargs(args):
if args.custom_processing:
assert hasattr(
schema_helpers, args.custom_processing
), f"Custom processing function {args.custom_processing} not found"
if not hasattr(schema_helpers, args.custom_processing):
logger.warning(
f"Custom processing function {args.custom_processing} not found"
)
if args.custom_processing == "rtransparent_pub_data_processing":
kwargs = rtrans_publication_kwargs
elif args.custom_processing == "irp_data_processing":
kwargs = irp_kwargs
elif args.custom_processing == "theneuro_data_processing":
kwargs = theneuro_kwargs
elif args.custom_processing == "manual_scoring_data_processing":
kwargs = manual_scoring_kwargs
else:
raise ValueError(
f"Kwargs associated with {args.custom_processing} not found"
@@ -95,10 +107,13 @@ def main():
args = parse_args()
tb = get_data(args)
upload_kwargs = get_upload_kwargs(args)
schema = upload_kwargs.pop("metrics_schema", schemas.RtransparentMetrics)

try:
db = pymongo.MongoClient(MONGODB_URI).osm
db.invocation.insert_many(transform_data(tb, **upload_kwargs), ordered=False)
db.invocation.insert_many(
transform_data(tb, metrics_schema=schema, **upload_kwargs), ordered=False
)
except Exception as e:
logger.error(f"Failed to process data: {e}")
raise e
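
A small sketch mirroring the script's kwargs handling (illustrative only; the dict literal repeats the values shown above): the optional "metrics_schema" entry is popped so that only Invocation metadata is forwarded to transform_data, and older datasets fall back to RtransparentMetrics.

# Hypothetical sketch of the upload script's schema resolution.
from osm import schemas
from osm.schemas import Component

manual_scoring_kwargs = {
    "data_tags": ["Manual Annotation NIMH/DSST"],
    "user_comment": "Bulk upload of some manually scored open code/data.",
    "components": [Component(name="ManualAnnotation-NIMHDSST", version="x.x.x")],
    "metrics_schema": schemas.ManualAnnotationNIMHDSST,
}

upload_kwargs = dict(manual_scoring_kwargs)
schema = upload_kwargs.pop("metrics_schema", schemas.RtransparentMetrics)
assert schema is schemas.ManualAnnotationNIMHDSST
assert "metrics_schema" not in upload_kwargs  # only Invocation metadata remains
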
11 changes: 9 additions & 2 deletions scripts/merge_funder.py
@@ -48,10 +48,15 @@ def get_user_args():
parser.add_argument("dataset_path", help="Path to the dataset file")
parser.add_argument("merge_col", default="pmcid_pmc", help="Column to merge on")
parser.add_argument(
"--funder_path",
"--funder-path",
help="Path to the funders file",
default="tempdata/funders.feather",
)
parser.add_argument(
"--metrics-schema",
help="Name of the schema class to use in order to validate the data",
default="RTransparentMetrics",
)
return parser.parse_args()


@@ -81,7 +86,9 @@ def main():

print("Converting to pyarrow...")
funder_field = pa.field("funder", pa.list_(pa.string()), nullable=True)
tb = get_table_with_schema(dataset.assign(funder=None), [funder_field])
tb = get_table_with_schema(
dataset.assign(funder=None), [funder_field], metrics_schema=args.metrics_schema
)

print("Merging with funders...")
merge_funder(funder, tb, args.merge_col)
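
Finally, a note on the new --metrics-schema flag (the snippet below is a sketch, not part of the diff): the flag carries a class name as a string, which get_pyarrow_schema resolves against the schemas package via getattr before converting the model's JSON schema to pyarrow.

# Illustrative sketch of resolving the --metrics-schema value by name.
from osm import schemas

schema_name = "RtransparentMetrics"  # default value of --metrics-schema
schema_cls = getattr(schemas, schema_name)
field_names = list(schema_cls.model_json_schema()["properties"])
print(len(field_names))
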
