diff --git a/osm/schemas/schema_helpers.py b/osm/schemas/schema_helpers.py
index 94634c83..05534d57 100644
--- a/osm/schemas/schema_helpers.py
+++ b/osm/schemas/schema_helpers.py
@@ -5,13 +5,13 @@
 
 import pandas as pd
 import pyarrow as pa
-from odmantic import SyncEngine
+from odmantic import EmbeddedModel, SyncEngine
 from pydantic import ValidationError
 from pymongo import MongoClient
 
 from osm import __version__, schemas
 from osm._utils import flatten_dict
-from osm.schemas import Client, Invocation, Work
+from osm.schemas import Client, Invocation, Work, metrics_schemas
 
 logger = logging.getLogger(__name__)
 
@@ -66,17 +66,17 @@ def odmantic_to_pyarrow(schema_dict):
     return pa.schema(fields)
 
 
-def get_pyarrow_schema(metrics_schema="RtransparentMetrics"):
-    odmantic_schema_json = getattr(schemas, metrics_schema).model_json_schema()
+def get_pyarrow_schema(schema_name="RtransparentMetrics"):
+    odmantic_schema_json = getattr(schemas, schema_name).model_json_schema()
     pyarrow_schema = odmantic_to_pyarrow(odmantic_schema_json)
     return pyarrow_schema
 
 
 def get_table_with_schema(
-    df, other_fields=None, raise_error=True, metrics_schema="RtransparentMetrics"
+    df, other_fields=None, raise_error=True, schema_name="RtransparentMetrics"
 ):
     other_fields = other_fields or []
-    pyarrow_schema = get_pyarrow_schema(metrics_schema)
+    pyarrow_schema = get_pyarrow_schema(schema_name)
     adjusted_schema = adjust_schema_to_dataframe(
         pyarrow_schema, df, other_fields=other_fields
     )
@@ -166,11 +166,32 @@ def get_invocation(row, metrics_schema, **kwargs):
     return invocation
 
 
-def get_data_from_mongo(aggregation: list[dict] | None = None) -> Iterator[dict]:
+def get_metrics_schemas():
+    schemas = {}
+    for attrib_name in dir(metrics_schemas):
+        attrib = getattr(metrics_schemas, attrib_name)
+        if isinstance(attrib, type) and issubclass(attrib, EmbeddedModel):
+            schemas[attrib_name] = getattr(metrics_schemas, attrib_name)
+    return schemas
+
+
+def get_all_data_from_mongo(aggregation: list[dict] | None = None) -> dict[str, pa.Table]:
+    mschemas = get_metrics_schemas()
+    tables = {}
+    for sname in mschemas.keys():
+        matches = get_data_from_mongo_for_schema(schema_name=sname)
+        table = matches_to_table(matches, schema_name=sname)
+        tables[sname] = table
+    return tables
+
+
+def get_data_from_mongo_for_schema(
+    aggregation: list[dict] | None = None, schema_name="RtransparentMetrics"
+) -> Iterator[dict]:
     if aggregation is None:
         aggregation = [
             {
-                "$match": {},
+                "$match": {"metrics_group": schema_name},
             },
             {
                 "$project": {
@@ -228,7 +249,9 @@ def infer_type_for_column(column):
     return pa.string()
 
 
-def matches_to_table(matches: Iterator[dict], batch_size: int = 1000) -> pa.Table:
+def matches_to_table(
+    matches: Iterator[dict], batch_size: int = 1000, schema_name="RtransparentMetrics"
+) -> pa.Table:
     # Initialize an empty list to store batches of tables
     tables = []
 
@@ -254,12 +277,18 @@ def matches_to_table(matches: Iterator[dict], batch_size: int = 1000) -> pa.Tabl
         df = df.drop(columns=["_id"])
 
         # Adjust schema according to pre-existing functionality
-        pyarrow_schema = get_pyarrow_schema()  # Get the base schema
+        pyarrow_schema = get_pyarrow_schema(
+            schema_name=schema_name
+        )  # Get the base schema
         adjusted_schema = adjust_schema_to_dataframe(
             pyarrow_schema, df
         )  # Adjust schema to match DataFrame
 
         # Extend schema to include any additional columns in the DataFrame
+        # though ideally returned queries should conform pretty well to the
+        # schema, so this is for entries outside of that. We may wish to consider
+        # searching for fields across every schema to find a type to eliminate
+        # the guesswork here.
         extra_columns = [col for col in df.columns if col not in adjusted_schema.names]
         for col in extra_columns:
             if col == "funder":
@@ -269,7 +298,7 @@
             elif col == "affiliation_country":
                 inferred_type = pa.list_(pa.string())
             elif col == "rtransparent_is_open_data":
-                inferred_type = pa.bool
+                inferred_type = pa.bool_()
             elif col == "manual_is_open_data":
                 inferred_type = pa.bool_()
             elif col == "created_at":
diff --git a/scripts/merge_funder.py b/scripts/merge_funder.py
index 4c2b3e91..4d37c36b 100644
--- a/scripts/merge_funder.py
+++ b/scripts/merge_funder.py
@@ -53,7 +53,7 @@ def get_user_args():
         default="tempdata/funders.feather",
     )
     parser.add_argument(
-        "--metrics-schema",
+        "--schema-name",
        help="Name of the schema class to use in order to validate the data",
         default="RTransparentMetrics",
     )
@@ -87,7 +87,7 @@ def main():
     print("Converting to pyarrow...")
     funder_field = pa.field("funder", pa.list_(pa.string()), nullable=True)
     tb = get_table_with_schema(
-        dataset.assign(funder=None), [funder_field], metrics_schema=args.metrics_schema
+        dataset.assign(funder=None), [funder_field], schema_name=args.schema_name
     )
 
     print("Merging with funders...")
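
Usage sketch (not part of the diff): the snippet below shows how the per-schema helpers added in schema_helpers.py might be driven end to end. The names get_all_data_from_mongo, get_data_from_mongo_for_schema, and matches_to_table come from the change above; the MongoDB deployment, the tempdata/ output directory, and the feather-writing step are assumptions made for illustration only.

    # Minimal sketch: pull one pyarrow Table per metrics schema and persist each.
    # Assumes a reachable MongoDB instance configured for osm and a writable
    # ./tempdata directory; neither is defined by this diff.
    import pyarrow.feather as feather

    from osm.schemas.schema_helpers import get_all_data_from_mongo

    # Each table is built from documents matched via {"metrics_group": <schema name>},
    # mirroring the $match stage added in get_data_from_mongo_for_schema.
    tables = get_all_data_from_mongo()
    for schema_name, table in tables.items():
        feather.write_feather(table, f"tempdata/{schema_name.lower()}.feather")

Callers that previously passed metrics_schema= to get_table_with_schema (or --metrics-schema to scripts/merge_funder.py) now pass schema_name= and --schema-name, as shown in the second file of the diff.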