update scripts
leej3 committed Aug 23, 2024
1 parent 2f02021 commit f65da5c
Showing 2 changed files with 230 additions and 45 deletions.
127 changes: 82 additions & 45 deletions scripts/invocation_upload.py
@@ -1,89 +1,126 @@
+import argparse
 import logging
 import os
-from typing import List
+from pathlib import Path

 import pandas as pd
 import pymongo
 from pydantic import ValidationError

-from osm.schemas import Client, Invocation, Work
+from osm import schemas

-DB_NAME = os.environ["DB_NAME"]
-MONGO_URI = os.environ["MONGO_URI"]
-ERROR_CSV_PATH = "error_log.csv"
-ERROR_LOG_PATH = "error.log"
-# NOTICE: output of rt_all without corresponding values in the all_indicators.csv from Rtransparent publication
-unmapped = {
-    "article": "",
-    "is_relevant": None,
-    "is_explicit": None,
-}
+MONGODB_URI = os.environ["MONGODB_URI"]
+ERROR_CSV_PATH = Path("error_log.csv")
+ERROR_LOG_PATH = Path("error.log")

 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 )
 logger = logging.getLogger(__name__)
-cols_mapping = {}


-def transform_data(df) -> List[Invocation]:
+def custom_irp_data_processing(row):
+    # row["pmid"] = row["article"]
+    return row
+
+
+def transform_data(df, tags=None, custom_processing=None) -> list[schemas.Invocation]:
     """Convert the dataframe to a list of Invocation objects"""
+    tags = tags or []
     for index, row in df.iterrows():
         try:
-            work = Work(
-                user_defined_id=str(row["doi"]),
-                pmid=str(row.pop("pmid")),
-                doi=str(row.pop("doi")),
+            if custom_processing is not None:
+                func = globals()[custom_processing]
+                row = func(row)
+
+            work = schemas.Work(
+                user_defined_id=row.get("doi") or row.get("pmid"),
+                pmid=row.get("pmid"),
+                doi=row.get("doi"),
                 openalex_id=None,
                 scopus_id=None,
-                filename=row.pop("filename"),
                 file=None,
+                filename=row.get("filename") or "",
                 content_hash=None,
             )
-            client = Client(compute_context_id=999, email=None)
+            client = schemas.Client(compute_context_id=999, email=None)

-            metrics = {**unmapped, **row.to_dict()}
-            invocation = Invocation(
+            metrics = {**row.to_dict()}
+            invocation = schemas.Invocation(
                 metrics=metrics,
                 osm_version="0.0.1",
                 client=client,
                 work=work,
-                user_comment="Initial database seeding with data from Rtransparent publication",
+                user_comment="Initial database seeding with publications from the NIH IRP",
+                data_tags=["bulk_upload", *tags],
+                funder=row.get("funder"),
+                components=[
+                    schemas.Component(name="Scibeam-parser/Rtransparent", version="NA")
+                ],
             )
-            yield invocation.dict()
+            yield invocation

-        except (KeyError, ValidationError) as e:
+        except (KeyError, ValidationError, Exception) as e:
+            breakpoint()
             logger.error(f"Error processing row {index}")
-            write_error_to_file(invocation, e)
-
-
-def write_error_to_file(invocation: Invocation, error: Exception):
-    with open(ERROR_CSV_PATH, "a") as csv_file, open(ERROR_LOG_PATH, "a") as log_file:
-        # Write the problematic invocation data to the CSV
-        row_dict = {
-            **invocation.metrics,
-            **invocation.work.dict(),
-            **invocation.client.dict(),
-        }
-        pd.DataFrame([row_dict]).to_csv(
-            csv_file, header=csv_file.tell() == 0, index=False
+            write_error_to_file(row, e)
+
+
+def write_error_to_file(row: pd.Series, error: Exception):
+    with ERROR_CSV_PATH.open("a") as csv_file, ERROR_LOG_PATH.open("a") as log_file:
+        # Write the problematic row data to the CSV, add header if not yet populated.
+        breakpoint()
+        row.to_csv(
+            csv_file,
+            header=not ERROR_CSV_PATH.exists() or ERROR_CSV_PATH.stat().st_size == 0,
+            index=False,
         )

-        # Log the error details
-        log_file.write(f"Error processing invocation: {invocation}\nError: {error}\n\n")
+        # Drop strings values as they tend to be too long
+        display_row = (
+            row.apply(lambda x: x if not isinstance(x, str) else None)
+            .dropna()
+            .to_dict()
+        )
+        log_file.write(f"Error processing data:\n {display_row}\nError: {error}\n\n")


+def parse_args():
+    parser = argparse.ArgumentParser(description="Invocation Upload")
+    parser.add_argument(
+        "-i", "--input_file", required=True, help="Path to the input file"
+    )
+    parser.add_argument(
+        "-t",
+        "--tags",
+        nargs="+",
+        help="Tags to apply to the uploaded data for filtering etc.",
+    )
+    parser.add_argument(
+        "-c",
+        "--custom-processing",
+        help="Name of function that applies custom processing to the data",
+    )
+    return parser.parse_args()
+
+
 def main():
-    df = pd.read_feather("all_indicators.feather", dtype_backend="pyarrow")
+    args = parse_args()
+    df = pd.read_feather(args.input_file, dtype_backend="pyarrow")
     if df.empty:
         raise Exception("Dataframe is empty")
     try:
-        db = pymongo.MongoClient(MONGO_URI).osm
-        db.invocation.insert_many(transform_data(df))
+        db = pymongo.MongoClient(MONGODB_URI).osm
+        new_docs = transform_data(
+            df, tags=args.tags, custom_processing=args.custom_processing
+        )
+        db.invocation.insert_many(
+            (new_doc.model_dump(mode="json", exclude="id") for new_doc in new_docs)
+        )
     except Exception as e:
-        breakpoint()
         logger.error(f"Failed to process data: {e}")
-        # raise e
+        breakpoint()
+        raise e


 if __name__ == "__main__":
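
Note on the new --custom-processing flag: transform_data resolves the supplied name with globals()[custom_processing], so the value must match a function defined at module level in this script, as custom_irp_data_processing is. A minimal sketch of the lookup pattern, using a hypothetical example_processing hook and a toy dataframe rather than the real feather input:

import pandas as pd

def example_processing(row):
    # Hypothetical hook: coerce the pmid to a string before schema validation.
    row["pmid"] = str(row["pmid"])
    return row

df = pd.DataFrame({"pmid": [123], "doi": ["10.1000/xyz"], "filename": ["a.xml"]})
name = "example_processing"   # what -c/--custom-processing would supply
func = globals()[name]        # same lookup transform_data performs
processed = func(df.iloc[0].copy())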
148 changes: 148 additions & 0 deletions scripts/merge_funder.py
@@ -0,0 +1,148 @@
import argparse
from pathlib import Path

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

from osm import schemas


def odmantic_to_pyarrow(schema_dict):
    # Type mapping from JSON schema types to pyarrow types
    type_mapping = {
        "integer": pa.int64(),
        "number": pa.float64(),
        "string": pa.string(),
        "boolean": pa.bool_(),
        # For simplicity, map null to string, but this will not be used
        "null": pa.string(),
        "array": pa.list_(pa.string()),  # Assuming array of strings; adjust as needed
        "object": pa.struct([]),  # Complex types can be handled differently
    }

    fields = []

    for prop, details in schema_dict["properties"].items():
        if "anyOf" in details:
            # Handle 'anyOf' by selecting the first non-null type
            primary_type = next(
                (t["type"] for t in details["anyOf"] if t["type"] != "null"), "string"
            )
            pyarrow_type = type_mapping[primary_type]
            nullable = True  # If 'anyOf' contains 'null', the field should be nullable
        else:
            # Directly map the type if 'anyOf' is not present
            pyarrow_type = type_mapping.get(details["type"], pa.string())
            nullable = False  # Assume fields without 'anyOf' are non-nullable

        # Create the field with the appropriate nullability
        fields.append(pa.field(prop, pyarrow_type, nullable=nullable))

    return pa.schema(fields)
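
odmantic_to_pyarrow walks the properties of the model's JSON schema and maps each JSON type to a pyarrow type, treating any anyOf that includes "null" as a nullable field. A small illustration on a hand-written schema fragment shaped like model_json_schema() output (not the real RtransparentMetrics schema), assuming the function above is in scope:

toy_schema = {
    "properties": {
        "pmid": {"type": "integer"},
        "doi": {"anyOf": [{"type": "string"}, {"type": "null"}]},
    }
}
pa_schema = odmantic_to_pyarrow(toy_schema)
# pmid maps to a non-nullable int64 field; doi maps to a nullable string field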


def read_parquet_chunks_and_combine(chunk_dir, pyarrow_schema):
    chunk_dir = Path(chunk_dir)
    all_files = sorted(chunk_dir.glob("*.parquet"))

    dfs = []
    for file in all_files:
        df = pd.read_parquet(file, schema=pyarrow_schema)
        dfs.append(df)

    combined_df = pd.concat(dfs, ignore_index=True)
    return combined_df


def save_combined_df_as_feather(df, output_file):
    df.reset_index(drop=True).to_feather(output_file)


def setup():
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_path", help="Path to the dataset file")
    args = parser.parse_args()
    dset_path = Path(args.dataset_path)
    dataset = pd.read_feather(dset_path, dtype_backend="pyarrow")
    if str(dset_path) == "tempdata/sharestats.feather":
        dataset = dataset.rename(columns={"article": "pmid"})

    df = pd.read_csv("tempdata/pmid-funding-matrix.csv")
    funder_columns = df.columns.difference(["pmid"])
    df["funder"] = df[funder_columns].apply(
        lambda x: funder_columns[x].tolist(), axis=1
    )
    funder = df.loc[df["funder"].astype(bool), ["pmid", "funder"]]
    return dataset, funder, dset_path
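
The funder column built in setup() comes from a boolean PMID-by-funder matrix: for each row, the apply keeps the names of the columns that are True, and the astype(bool) filter drops PMIDs whose list is empty. A toy stand-in for tempdata/pmid-funding-matrix.csv, with made-up funder columns:

import pandas as pd

df = pd.DataFrame({"pmid": [1, 2], "NIMH": [True, False], "NINDS": [True, False]})
funder_columns = df.columns.difference(["pmid"])
df["funder"] = df[funder_columns].apply(
    lambda x: funder_columns[x].tolist(), axis=1
)
funder = df.loc[df["funder"].astype(bool), ["pmid", "funder"]]
# pmid 1 -> ["NIMH", "NINDS"]; pmid 2 is dropped because no funder column is True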


def merge_funder(row, funder):
    pmid = row["pmid"]
    funder_info = funder[funder["pmid"] == pmid]

    if not funder_info.empty:
        row["funder"] = funder_info.iloc[0]["funder"]
    else:
        row["funder"] = []

    return row


def subset_schema_to_dataframe(schema, df):
    # Filter schema fields to only those present in the DataFrame
    fields = [field for field in schema if field.name in df.columns]
    return pa.schema(fields)


def main():
    odmantic_schema_json = schemas.RtransparentMetrics.model_json_schema()
    pyarrow_schema = odmantic_to_pyarrow(odmantic_schema_json)

    dataset, funder, dset_path = setup()

    adjusted_schema = subset_schema_to_dataframe(pyarrow_schema, dataset)

    output_dir = Path(f"tempdata/{dset_path.stem}-chunks")
    output_dir.mkdir(parents=True, exist_ok=True)

    chunk_index = 0
    collected_rows = []

    for _, row in dataset.iterrows():
        fixed_row = merge_funder(row, funder)
        collected_rows.append(fixed_row)

        if len(collected_rows) >= 1000:
            chunk_file = output_dir / f"chunk_{chunk_index}.parquet"
            pq.write_table(
                pa.Table.from_pandas(
                    pd.DataFrame(collected_rows), schema=adjusted_schema
                ),
                chunk_file,
                compression="snappy",
            )
            collected_rows = []
            chunk_index += 1

    if collected_rows:
        chunk_file = output_dir / f"chunk_{chunk_index}.parquet"
        try:
            pq.write_table(
                pa.Table.from_pandas(
                    pd.DataFrame(collected_rows), schema=adjusted_schema
                ),
                chunk_file,
                compression="snappy",
            )
        except ValueError:
            breakpoint()

    df_out = read_parquet_chunks_and_combine(output_dir, adjusted_schema)
    save_combined_df_as_feather(
        df_out, dset_path.parent / f"{dset_path.stem}-with-funder.feather"
    )


if __name__ == "__main__":
    main()
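
Taken together, the two scripts appear to form a small seeding pipeline: merge_funder.py writes a <dataset>-with-funder.feather file next to the input dataset, and that output is presumably what invocation_upload.py consumes via -i/--input_file, where each row's funder list feeds funder=row.get("funder"). A quick sanity check of the merged output, assuming the script was run on tempdata/sharestats.feather as in setup()'s special case:

import pandas as pd

df = pd.read_feather("tempdata/sharestats-with-funder.feather", dtype_backend="pyarrow")
print(df["funder"].head())  # list-valued column later consumed by invocation_upload.py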
