Commit
improve mongo query
leej3 committed Aug 28, 2024
1 parent 159e899 commit 77f3394
Showing 2 changed files with 42 additions and 13 deletions.
51 changes: 40 additions & 11 deletions osm/schemas/schema_helpers.py
@@ -5,13 +5,13 @@

 import pandas as pd
 import pyarrow as pa
-from odmantic import SyncEngine
+from odmantic import EmbeddedModel, SyncEngine
 from pydantic import ValidationError
 from pymongo import MongoClient

 from osm import __version__, schemas
 from osm._utils import flatten_dict
-from osm.schemas import Client, Invocation, Work
+from osm.schemas import Client, Invocation, Work, metrics_schemas

 logger = logging.getLogger(__name__)
@@ -66,17 +66,17 @@ def odmantic_to_pyarrow(schema_dict):
     return pa.schema(fields)


-def get_pyarrow_schema(metrics_schema="RtransparentMetrics"):
-    odmantic_schema_json = getattr(schemas, metrics_schema).model_json_schema()
+def get_pyarrow_schema(schema_name="RtransparentMetrics"):
+    odmantic_schema_json = getattr(schemas, schema_name).model_json_schema()
     pyarrow_schema = odmantic_to_pyarrow(odmantic_schema_json)
     return pyarrow_schema


 def get_table_with_schema(
-    df, other_fields=None, raise_error=True, metrics_schema="RtransparentMetrics"
+    df, other_fields=None, raise_error=True, schema_name="RtransparentMetrics"
 ):
     other_fields = other_fields or []
-    pyarrow_schema = get_pyarrow_schema(metrics_schema)
+    pyarrow_schema = get_pyarrow_schema(schema_name)
     adjusted_schema = adjust_schema_to_dataframe(
         pyarrow_schema, df, other_fields=other_fields
     )
@@ -166,11 +166,32 @@ def get_invocation(row, metrics_schema, **kwargs):
     return invocation


-def get_data_from_mongo(aggregation: list[dict] | None = None) -> Iterator[dict]:
+def get_metrics_schemas():
+    schemas = {}
+    for attrib_name in dir(metrics_schemas):
+        attrib = getattr(metrics_schemas, attrib_name)
+        if isinstance(attrib, type) and issubclass(attrib, EmbeddedModel):
+            schemas[attrib_name] = getattr(metrics_schemas, attrib_name)
+    return schemas
+
+
+def get_all_data_from_mongo(aggregation: list[dict] | None = None) -> dict[pa.Table]:
+    mschemas = get_metrics_schemas()
+    tables = {}
+    for sname in mschemas.keys():
+        matches = get_data_from_mongo_for_schema(schema_name=sname)
+        table = matches_to_table(matches, schema_name=sname)
+        tables[sname] = table
+    return tables
+
+
+def get_data_from_mongo_for_schema(
+    aggregation: list[dict] | None = None, schema_name="RtransparentMetrics"
+) -> Iterator[dict]:
     if aggregation is None:
         aggregation = [
             {
-                "$match": {},
+                "$match": {"metrics_group": schema_name},
             },
             {
                 "$project": {
@@ -228,7 +249,9 @@ def infer_type_for_column(column):
     return pa.string()


-def matches_to_table(matches: Iterator[dict], batch_size: int = 1000) -> pa.Table:
+def matches_to_table(
+    matches: Iterator[dict], batch_size: int = 1000, schema_name="RtransparentMetrics"
+) -> pa.Table:
     # Initialize an empty list to store batches of tables
     tables = []

@@ -254,12 +277,18 @@ def matches_to_table(matches: Iterator[dict], batch_size: int = 1000) -> pa.Table:
     df = df.drop(columns=["_id"])

     # Adjust schema according to pre-existing functionality
-    pyarrow_schema = get_pyarrow_schema()  # Get the base schema
+    pyarrow_schema = get_pyarrow_schema(
+        schema_name=schema_name
+    )  # Get the base schema
     adjusted_schema = adjust_schema_to_dataframe(
         pyarrow_schema, df
     )  # Adjust schema to match DataFrame

     # Extend schema to include any additional columns in the DataFrame
+    # though ideally returned queries should conform pretty well to the
+    # schema so this is for entries outside of that. We may wish to consider
+    # searching for fields across every schema to find a type to eliminate
+    # the guesswork here.
     extra_columns = [col for col in df.columns if col not in adjusted_schema.names]
     for col in extra_columns:
         if col == "funder":
@@ -269,7 +298,7 @@ def matches_to_table(matches: Iterator[dict], batch_size: int = 1000) -> pa.Table:
         elif col == "affiliation_country":
             inferred_type = pa.list_(pa.string())
         elif col == "rtransparent_is_open_data":
-            inferred_type = pa.bool
+            inferred_type = pa.bool_()
         elif col == "manual_is_open_data":
             inferred_type = pa.bool_()
         elif col == "created_at":
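Taken together, the new helpers split the Mongo query by metrics schema: get_metrics_schemas discovers every EmbeddedModel subclass exported by osm.schemas.metrics_schemas, get_data_from_mongo_for_schema matches only documents whose metrics_group field equals that schema's name, and get_all_data_from_mongo returns one pyarrow table per schema. A minimal usage sketch (not part of the commit), assuming the osm package is installed and the module's MongoDB connection settings point at a reachable database:

from osm.schemas.schema_helpers import get_all_data_from_mongo

# One pyarrow.Table per metrics schema (e.g. RtransparentMetrics); each table
# is built only from documents whose metrics_group matches that schema's name.
tables = get_all_data_from_mongo()
for schema_name, table in tables.items():
    print(f"{schema_name}: {table.num_rows} rows, {len(table.schema)} columns")
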
4 changes: 2 additions & 2 deletions scripts/merge_funder.py
@@ -53,7 +53,7 @@ def get_user_args():
         default="tempdata/funders.feather",
     )
     parser.add_argument(
-        "--metrics-schema",
+        "--schema-name",
         help="Name of the schema class to use in order to validate the data",
         default="RTransparentMetrics",
     )
@@ -87,7 +87,7 @@ def main():
     print("Converting to pyarrow...")
     funder_field = pa.field("funder", pa.list_(pa.string()), nullable=True)
     tb = get_table_with_schema(
-        dataset.assign(funder=None), [funder_field], metrics_schema=args.metrics_schema
+        dataset.assign(funder=None), [funder_field], schema_name=args.schema_name
     )

     print("Merging with funders...")
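The script change simply follows the rename of the keyword argument on get_table_with_schema. A sketch of the updated call, with a hypothetical two-row DataFrame standing in for the feather dataset merge_funder.py actually loads (the column name is illustrative, not taken from the schema):

import pandas as pd
import pyarrow as pa

from osm.schemas.schema_helpers import get_table_with_schema

# Hypothetical stand-in for the feather dataset the script reads from disk;
# real input would carry the RtransparentMetrics columns.
dataset = pd.DataFrame({"pmid": [111, 222]})

# funder column is added empty and typed explicitly, as in the script; the
# schema name would normally come from the renamed --schema-name option.
funder_field = pa.field("funder", pa.list_(pa.string()), nullable=True)
tb = get_table_with_schema(
    dataset.assign(funder=None), [funder_field], schema_name="RtransparentMetrics"
)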
