From 03069b87491b99a039f904ca8adcaee0c9dd1c11 Mon Sep 17 00:00:00 2001
From: leej3
Date: Fri, 9 Aug 2024 16:39:34 +0100
Subject: [PATCH] switch to bulk upload with generator

---
 scripts/invocation_upload.py | 144 ++++++-----------------------------
 1 file changed, 23 insertions(+), 121 deletions(-)

diff --git a/scripts/invocation_upload.py b/scripts/invocation_upload.py
index 468ce400..780ff696 100644
--- a/scripts/invocation_upload.py
+++ b/scripts/invocation_upload.py
@@ -1,25 +1,17 @@
-import asyncio
 import logging
 import os
-import pickle
-import tempfile
-from pathlib import Path
 from typing import List
 
-import numpy as np
 import pandas as pd
-import requests
-from motor.motor_asyncio import AsyncIOMotorClient
-from odmantic import AIOEngine
+import pymongo
 from pydantic import ValidationError
 
 from osm.schemas import Client, Invocation, Work
 
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-)
-logger = logging.getLogger(__name__)
-
+DB_NAME = os.environ["DB_NAME"]
+MONGO_URI = os.environ["MONGO_URI"]
+ERROR_CSV_PATH = "error_log.csv"
+ERROR_LOG_PATH = "error.log"
 # NOTICE: output of rt_all without corresponding values in the all_indicators.csv from Rtransparent publication
 unmapped = {
     "article": "",
@@ -27,10 +19,14 @@
     "is_explicit": None,
 }
 
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
 
 def transform_data(df) -> List[Invocation]:
-    """Handles data transformation as well as mapping"""
-    invocations = []
+    """Convert the dataframe to a list of Invocation objects"""
     for index, row in df.iterrows():
         try:
             work = Work(
@@ -53,34 +49,11 @@ def transform_data(df) -> List[Invocation]:
                 work=work,
                 user_comment="Initial database seeding with data from Rtransparent publication",
            )
+            yield invocation.dict()
 
-            invocations.append(invocation)
         except (KeyError, ValidationError) as e:
-            if isinstance(e, KeyError):
-                raise KeyError(f"Error key not found in row {index}: {e}")
-            elif isinstance(e, ValidationError):
-                raise e
-
-    return invocations
-
-
-def read_data(data_path: str):
-    """Checks to see if url is a path or https to download or read file"""
-    try:
-        if data_path.startswith("https"):
-            csv = download_csv(data_path)
-            df = pd.read_csv(csv)
-        elif data_path.endswith("csv"):
-            df = pd.read_csv(data_path)
-        else:
-            df = pd.read_feather(data_path)
-        return df
-    except FileNotFoundError as e:
-        raise e
-
-
-ERROR_CSV_PATH = "error_log.csv"
-ERROR_LOG_PATH = "error.log"
+            logger.error(f"Error processing row {index}")
+            write_error_to_file(invocation, e)
 
 
 def write_error_to_file(invocation: Invocation, error: Exception):
@@ -99,90 +72,19 @@ def write_error_to_file(invocation: Invocation, error: Exception):
         log_file.write(f"Error processing invocation: {invocation}\nError: {error}\n\n")
 
 
-async def upload_data(
-    invocations: List[Invocation], mongo_uri: str, db_name: str, count
-):
-    """upload invocations to MongoDB one after the other to prevent timeout"""
-    motor_client = AsyncIOMotorClient(mongo_uri)
-    engine = AIOEngine(client=motor_client, database=db_name)
+def main():
+    df = pd.read_feather("all_indicators.feather", dtype_backend="pyarrow")
+    if df.empty:
+        raise Exception("Dataframe is empty")
     try:
-        await engine.save_all(invocations)
-        logger.info(f"Upload successful {count}")
-    except (TypeError, Exception) as e:
-        breakpoint()
-        logger.error(f"Error uploading batch: {e} {count}")
-        for inv in invocations:
-            try:
-                await engine.save(inv)
-            except Exception as e:
-                write_error_to_file(inv, e)
-                raise e
-    finally:
-        motor_client.close()
-
-
-def download_csv(url):
-    """downloads file and store in a temp location"""
-    try:
-        response = requests.get(url)
-        if response.status_code == 200:
-            temp_file, temp_file_path = tempfile.mkstemp(suffix=".csv")
-            os.close(temp_file)  # Close the file descriptor
-            with open(temp_file_path, "wb") as file:
-                file.write(response.content)
-            return temp_file_path
-        else:
-            raise Exception(
-                f"Failed to download CSV. Status code: {response.status_code}"
-            )
-    except Exception as e:
-        raise e
-
-
-def break_into_chunks(dataframe):
-    batch_size = 1000  # Define your batch size
-    number_of_batches = int(np.ceil(len(dataframe) / batch_size))
-
-    # Use np.array_split to create batches
-    chunks = np.array_split(dataframe, number_of_batches)
-    return chunks
-
-
-def main(db_url: str, db_name: str, data_path="all_indicators.feather"):
-    try:
-        transformed_pickle = Path("invocations.pkl")
-        if transformed_pickle.exists():
-            df = pickle.loads(transformed_pickle.read_bytes())
-        else:
-            breakpoint()
-            df = read_data(data_path)
-        if not df.empty:
-            chunks = break_into_chunks(df)
-            ind = 1
-            for chunk in chunks:
-                invocations = transform_data(chunk)
-
-                logger.info(f"Uploading data to {db_url} {ind}/{len(chunks)}")
-                asyncio.run(
-                    upload_data(
-                        invocations, db_url, db_name, f"{ind}/{len(chunks)}"
-                    )
-                )
-                ind = ind + 1
-
-            # transformed_pickle.write_bytes(pickle.dumps(invocations))
-        else:
-            raise Exception("Dataframe is empty")
+        db = pymongo.MongoClient(MONGO_URI).osm
+        db.invocation.insert_many(transform_data(df))
     except Exception as e:
         breakpoint()
        logger.error(f"Failed to process data: {e}")
-        raise e
+        # raise e
+        breakpoint()
 
 
 if __name__ == "__main__":
-    url = os.getenv(
-        "DATABASE_URL",
-        None,
-    )
-    name = os.getenv("DATABASE_NAME", None)
-    main(url, name)
+    main()
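
Note: the new main() passes the transform_data() generator straight to pymongo's insert_many(), which accepts any iterable of documents, so rows are streamed rather than materialised as one large list. For reference, below is a minimal, self-contained sketch of the same bulk-upload-with-generator pattern. The docs() generator, the batch size, and the ordered=False flag are illustrative assumptions and not part of this patch; only the MONGO_URI and DB_NAME environment variables mirror the ones the patched script reads.

import os
from itertools import islice

import pymongo
from pymongo.errors import BulkWriteError

MONGO_URI = os.environ["MONGO_URI"]  # same env vars the patched script reads at import time
DB_NAME = os.environ["DB_NAME"]


def docs():
    """Hypothetical stand-in for transform_data(): yield one document dict per row."""
    for i in range(10):
        yield {"osm_version": "0.0.1", "row": i}


def bulk_upload(generator, batch_size=1000):
    collection = pymongo.MongoClient(MONGO_URI)[DB_NAME].invocation
    while True:
        # Consume the generator in fixed-size batches so no more than
        # batch_size documents are ever held in memory at once.
        batch = list(islice(generator, batch_size))
        if not batch:
            break
        try:
            # ordered=False lets the remaining documents in a batch insert
            # even if one of them is rejected server-side.
            collection.insert_many(batch, ordered=False)
        except BulkWriteError as e:
            print(f"{len(e.details.get('writeErrors', []))} documents failed in this batch")


if __name__ == "__main__":
    bulk_upload(docs())

Batching with islice is optional; insert_many() consumes the generator directly as the patch does, but fixed-size batches keep each bulk write small and make partial failures easier to report.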