switch to bulk upload with generator
leej3 committed Aug 9, 2024
1 parent e8a88b5 commit 03069b8
Showing 1 changed file with 23 additions and 121 deletions.
144 changes: 23 additions & 121 deletions scripts/invocation_upload.py
@@ -1,36 +1,32 @@
import asyncio
import logging
import os
import pickle
import tempfile
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
import requests
from motor.motor_asyncio import AsyncIOMotorClient
from odmantic import AIOEngine
import pymongo
from pydantic import ValidationError

from osm.schemas import Client, Invocation, Work

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

DB_NAME = os.environ["DB_NAME"]
MONGO_URI = os.environ["MONGO_URI"]
ERROR_CSV_PATH = "error_log.csv"
ERROR_LOG_PATH = "error.log"
# NOTICE: output of rt_all without corresponding values in the all_indicators.csv from Rtransparent publication
unmapped = {
    "article": "",
    "is_relevant": None,
    "is_explicit": None,
}

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def transform_data(df) -> List[Invocation]:
    """Handles data transformation as well as mapping"""
    invocations = []
    """Convert the dataframe to a list of Invocation objects"""
    for index, row in df.iterrows():
        try:
            work = Work(
@@ -53,34 +49,11 @@ def transform_data(df) -> List[Invocation]:
                work=work,
                user_comment="Initial database seeding with data from Rtransparent publication",
            )
            yield invocation.dict()

            invocations.append(invocation)
        except (KeyError, ValidationError) as e:
            if isinstance(e, KeyError):
                raise KeyError(f"Error key not found in row {index}: {e}")
            elif isinstance(e, ValidationError):
                raise e

    return invocations


def read_data(data_path: str):
    """Checks to see if url is a path or https to download or read file"""
    try:
        if data_path.startswith("https"):
            csv = download_csv(data_path)
            df = pd.read_csv(csv)
        elif data_path.endswith("csv"):
            df = pd.read_csv(data_path)
        else:
            df = pd.read_feather(data_path)
        return df
    except FileNotFoundError as e:
        raise e


ERROR_CSV_PATH = "error_log.csv"
ERROR_LOG_PATH = "error.log"
            logger.error(f"Error processing row {index}")
            write_error_to_file(invocation, e)


def write_error_to_file(invocation: Invocation, error: Exception):
@@ -99,90 +72,19 @@ def write_error_to_file(invocation: Invocation, error: Exception):
        log_file.write(f"Error processing invocation: {invocation}\nError: {error}\n\n")


async def upload_data(
    invocations: List[Invocation], mongo_uri: str, db_name: str, count
):
    """upload invocations to MongoDB one after the other to prevent timeout"""
    motor_client = AsyncIOMotorClient(mongo_uri)
    engine = AIOEngine(client=motor_client, database=db_name)
def main():
    df = pd.read_feather("all_indicators.feather", dtype_backend="pyarrow")
    if df.empty:
        raise Exception("Dataframe is empty")
    try:
        await engine.save_all(invocations)
        logger.info(f"Upload successful {count}")
    except (TypeError, Exception) as e:
        breakpoint()
        logger.error(f"Error uploading batch: {e} {count}")
        for inv in invocations:
            try:
                await engine.save(inv)
            except Exception as e:
                write_error_to_file(inv, e)
                raise e
    finally:
        motor_client.close()


def download_csv(url):
    """downloads file and store in a temp location"""
    try:
        response = requests.get(url)
        if response.status_code == 200:
            temp_file, temp_file_path = tempfile.mkstemp(suffix=".csv")
            os.close(temp_file)  # Close the file descriptor
            with open(temp_file_path, "wb") as file:
                file.write(response.content)
            return temp_file_path
        else:
            raise Exception(
                f"Failed to download CSV. Status code: {response.status_code}"
            )
    except Exception as e:
        raise e


def break_into_chunks(dataframe):
    batch_size = 1000  # Define your batch size
    number_of_batches = int(np.ceil(len(dataframe) / batch_size))

    # Use np.array_split to create batches
    chunks = np.array_split(dataframe, number_of_batches)
    return chunks


def main(db_url: str, db_name: str, data_path="all_indicators.feather"):
    try:
        transformed_pickle = Path("invocations.pkl")
        if transformed_pickle.exists():
            df = pickle.loads(transformed_pickle.read_bytes())
        else:
            breakpoint()
            df = read_data(data_path)
        if not df.empty:
            chunks = break_into_chunks(df)
            ind = 1
            for chunk in chunks:
                invocations = transform_data(chunk)

                logger.info(f"Uploading data to {db_url} {ind}/{len(chunks)}")
                asyncio.run(
                    upload_data(
                        invocations, db_url, db_name, f"{ind}/{len(chunks)}"
                    )
                )
                ind = ind + 1

                # transformed_pickle.write_bytes(pickle.dumps(invocations))
        else:
            raise Exception("Dataframe is empty")
        db = pymongo.MongoClient(MONGO_URI).osm
        db.invocation.insert_many(transform_data(df))
    except Exception as e:
        breakpoint()
        logger.error(f"Failed to process data: {e}")
        raise e
        # raise e
        breakpoint()


if __name__ == "__main__":
    url = os.getenv(
        "DATABASE_URL",
        None,
    )
    name = os.getenv("DATABASE_NAME", None)
    main(url, name)
    main()
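
For reference, a minimal sketch of the pattern this commit adopts: transform_data becomes a generator that yields one document dict per row, and pymongo's insert_many consumes it lazily as a single bulk write. The collection name (osm.invocation), the MONGO_URI environment variable, and the read_feather call mirror the diff above; the simplified document fields are illustrative assumptions, not the project's real Work/Invocation schema.

# Sketch only: generator-based bulk upload with pymongo.
# The actual script builds pydantic models (Work, Client, Invocation)
# and yields invocation.dict(); the fields below are placeholders.
import os

import pandas as pd
import pymongo


def transform_data(df):
    """Yield one MongoDB document per dataframe row."""
    for index, row in df.iterrows():
        yield {
            "work": {"doi": row.get("doi")},  # illustrative field only
            "user_comment": "Initial database seeding",
        }


def main():
    df = pd.read_feather("all_indicators.feather")
    if df.empty:
        raise Exception("Dataframe is empty")
    db = pymongo.MongoClient(os.environ["MONGO_URI"]).osm
    # insert_many accepts any iterable of documents, so the generator is
    # consumed without materializing the whole list in memory first.
    db.invocation.insert_many(transform_data(df))


if __name__ == "__main__":
    main()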
