diff --git a/osm/schemas.py b/osm/schemas.py
index 1ee5f51f..17627a32 100644
--- a/osm/schemas.py
+++ b/osm/schemas.py
@@ -43,7 +43,7 @@ class Metrics(EmbeddedModel):
 
 class Client(EmbeddedModel):
     compute_context_id: int
-    email: Optional[EmailStr]
+    email: Optional[EmailStr] = None
 
 
 class Work(EmbeddedModel):
@@ -62,8 +62,8 @@ class Work(EmbeddedModel):
     openalex_id: Optional[str] = None
     scopus_id: Optional[str] = None
     filename: str
-    file: str
-    content_hash: str
+    file: Optional[str] = None
+    content_hash: Optional[str] = None
 
 
 class Invocation(Model):
@@ -81,8 +81,8 @@ class Invocation(Model):
 
 
 # components: list[Component]
-# Rtransparent:
-# Component.construct(name="rtransparent", version="0.13", docker_image="nimh-dsst/rtransparent:0.13", docker_image_id="dsjfkldsjflkdsjlf2jkl23j")
-# Derivative.construct(name="rtransparent", version="0.13", docker_image="nimh-dsst/rtransparent:0.13", docker_image_id="dsjfkldsjflkdsjlf2jkl23j")
-# ScibeamParser:
-# Component.construct(name="scibeam-parser", version="0.5.1", docker_image="elife/scibeam-parser:0.5.1", docker_image_id="dsjfkldsjflkdsjlf2jkl23j")
+# Rtransparent: Component.construct(name="rtransparent", version="0.13", docker_image="nimh-dsst/rtransparent:0.13",
+# docker_image_id="dsjfkldsjflkdsjlf2jkl23j") Derivative.construct(name="rtransparent", version="0.13",
+# docker_image="nimh-dsst/rtransparent:0.13", docker_image_id="dsjfkldsjflkdsjlf2jkl23j") ScibeamParser:
+# Component.construct(name="scibeam-parser", version="0.5.1", docker_image="elife/scibeam-parser:0.5.1",
+# docker_image_id="dsjfkldsjflkdsjlf2jkl23j")
diff --git a/scripts/invocation_upload.py b/scripts/invocation_upload.py
new file mode 100644
index 00000000..341fa402
--- /dev/null
+++ b/scripts/invocation_upload.py
@@ -0,0 +1,136 @@
+import asyncio
+import logging
+import os
+import pickle
+import tempfile
+from pathlib import Path
+from typing import List
+
+import pandas as pd
+import requests
+
+# from motor.motor_tornado import MotorClient
+from motor.motor_asyncio import AsyncIOMotorClient
+from odmantic import AIOEngine
+from pydantic import ValidationError
+
+from osm.schemas import Client, Invocation, Work
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+# NOTICE: fields produced by rt_all that have no corresponding values in the
+# all_indicators.csv from the Rtransparent publication
+unmapped = {
+    "article": "",
+    "is_relevant": None,
+    "is_explicit": None,
+}
+
+
+def transform_data(df: pd.DataFrame) -> List[Invocation]:
+    """Map each dataframe row onto the Work/Client/Invocation schemas."""
+    invocations = []
+    for index, row in df.iterrows():
+        try:
+            work = Work(
+                user_defined_id=str(row["doi"]),
+                pmid=str(row.pop("pmid")),
+                doi=str(row.pop("doi")),
+                openalex_id=None,
+                scopus_id=None,
+                filename=row.pop("filename"),
+                file=None,
+                content_hash=None,
+            )
+            client = Client(compute_context_id=999, email=None)
+
+            # Everything not consumed by Work/Client above is treated as a metric.
+            metrics = {**unmapped, **row.to_dict()}
+            invocation = Invocation(
+                metrics=metrics,
+                osm_version="0.0.1",
+                client=client,
+                work=work,
+                user_comment="Initial database seeding with data from Rtransparent publication",
+            )
+
+            invocations.append(invocation)
+        except KeyError as e:
+            raise KeyError(f"Error key not found in row {index}: {e}")
+        except ValidationError:
+            raise
+
+    return invocations
+
+
+def read_data(data_path: str):
+    """Download a CSV when given an https URL, otherwise read a local feather file."""
+    if data_path.startswith("https"):
+        csv = download_csv(data_path)
+        df = pd.read_csv(csv)
+    else:
+        df = pd.read_feather(data_path)
+    return df
+
+
+async def upload_data(invocations: List[Invocation], mongo_uri: str, db_name: str):
+    """Upload invocations to MongoDB one at a time to prevent timeouts."""
+    motor_client = AsyncIOMotorClient(mongo_uri)
+    try:
+        engine = AIOEngine(client=motor_client, database=db_name)
+        for invocation in invocations:
+            await engine.save(invocation)
+    except TypeError:
+        raise
+    except Exception as e:
+        raise Exception(f"Failed to upload data: {e}")
+    finally:
+        motor_client.close()
+
+
+def download_csv(url):
+    """Download the file at `url` into a temporary location and return its path."""
+    response = requests.get(url)
+    if response.status_code == 200:
+        temp_file, temp_file_path = tempfile.mkstemp(suffix=".csv")
+        os.close(temp_file)  # Close the file descriptor
+        with open(temp_file_path, "wb") as file:
+            file.write(response.content)
+        return temp_file_path
+    raise Exception(f"Failed to download CSV. Status code: {response.status_code}")
+
+
+def main(data_path="all_indicators.feather"):
+    try:
+        transformed_pickle = Path("invocations.pkl")
+        if transformed_pickle.exists():
+            # Reuse previously transformed invocations if they were cached.
+            invocations = pickle.loads(transformed_pickle.read_bytes())
+        else:
+            df = read_data(data_path)
+            if df.empty:
+                raise Exception("Dataframe is empty")
+            invocations = transform_data(df)
+            transformed_pickle.write_bytes(pickle.dumps(invocations))
+        db_url = os.getenv("DATABASE_URL", None)
+        db_name = os.getenv("DATABASE_NAME", None)
+        logger.warning(f"Uploading data to {db_url}")
+        asyncio.run(upload_data(invocations, db_url, db_name))
+    except Exception as e:
+        logger.error(f"Failed to process data: {e}")
+        raise e
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/scripts/__init__.py b/tests/scripts/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/scripts/test_invocation_upload.py b/tests/scripts/test_invocation_upload.py
new file mode 100644
index 00000000..386a0664
--- /dev/null
+++ b/tests/scripts/test_invocation_upload.py
@@ -0,0 +1,292 @@
+import logging
+import re
+from typing import List
+
+import pandas as pd
+import pytest
+from pydantic import ValidationError
+from requests.models import Response
+
+from osm.schemas import Invocation
+from scripts.invocation_upload import main, read_data, transform_data, upload_data
+
+invocations_row = [{
+    "article": "tmph81dyo65",
+    "pmid": "/tmp/tmph81dyo65.xml",
+    "is_coi_pred": False,
+    "coi_text": "",
+    "is_funded_pred": True,
+    "fund_text": "fund_text",
+    "funding_text": "Reliability and predictability of phenotype information from functional connectivity in large imaging datasets...",
+    "support_1": True,
+    "support_3": True,
+    "support_4": False,
+    "support_5": True,
+    "support_6": False,
+    "support_7": False,
+    "support_8": False,
+    "support_9": False,
+    "support_10": False,
+    "developed_1": False,
+    "received_1": False,
+    "received_2": False,
+    "recipient_1": False,
+    "authors_1": False,
+    "authors_2": False,
+    "thank_1": False,
+    "thank_2": False,
+    "fund_1": False,
+    "fund_2": False,
+    "fund_3": False,
+    "supported_1": False,
+    "financial_1": False,
+    "financial_2": False,
+    "financial_3": False,
+    "grant_1": False,
+    "french_1": False,
+    "common_1": False,
+    "common_2": False,
+    "common_3": False,
+    "common_4": False,
+    "common_5": False,
+    "acknow_1": False,
+    "disclosure_1": False,
+    "disclosure_2": False,
+    "is_register_pred": False,
+    "register_text": "",
+    "is_relevant": True,
+    "is_method": False,
+    "is_NCT": -2147483648,
+    "is_explicit": -2147483648,
+    "doi": "test",
+    "filename": "test",
+    "is_fund_pred": True
+}]
+
+# Define the list of column names based on the keys of the metrics dictionary
+columns = [
+    "article",  # no mapping
+    "pmid",
+    "is_coi_pred",
+    "coi_text",
+    "is_fund_pred",
+    "is_funded_pred",
+    "funding_text",
+    "support_1",
+    "support_3",
+    "support_4",
+    "support_5",
+    "support_6",
+    "support_7",
+    "support_8",
+    "support_9",
+    "support_10",
+    "developed_1",
+    "received_1",
+    "received_2",
+    "recipient_1",
+    "fund_text",
+    "authors_1",
+    "authors_2",
+    "thank_1",
+    "thank_2",
+    "fund_1",
+    "fund_2",
+    "fund_3",
+    "supported_1",
+    "financial_1",
+    "financial_2",
+    "financial_3",
+    "grant_1",
+    "french_1",
+    "common_1",
+    "common_2",
+    "common_3",
+    "common_4",
+    "common_5",
+    "acknow_1",
+    "disclosure_1",
+    "disclosure_2",
+    "is_register_pred",
+    "register_text",
+    "is_relevant",
+    "is_method",
+    "is_NCT",
+    "is_explicit",
+    "doi",
+    "filename",
+]
+
+
+class MockAIOEngine:
+    def __init__(self, client, database):
+        self.client = client
+        self.database = database
+
+    async def save(self, invocation):
+        pass
+
+
+class MockAsyncIOMotorClient:
+    def __init__(self, uri):
+        self.uri = uri
+
+    def close(self):
+        pass  # Simulate closing the connection
+
+
+@pytest.fixture
+def mock_database(monkeypatch, mocker):
+    # Mock AsyncIOMotorClient
+    monkeypatch.setattr("motor.motor_asyncio.AsyncIOMotorClient", MockAsyncIOMotorClient)
+    # Mock AIOEngine
+    monkeypatch.setattr("odmantic.AIOEngine", MockAIOEngine)
+
+
+@pytest.fixture
+def mock_read_data(mocker):
+    dataframe = pd.DataFrame(invocations_row, columns=columns)
+    mocker.patch("scripts.invocation_upload.read_data", return_value=dataframe)
+
+
+@pytest.fixture
+def mock_update_data(mocker):
+    # Mock the uploader so no database connection is needed
+    mocker.patch(
+        "scripts.invocation_upload.upload_data",
+        return_value={"message": "upload successful"},
+    )
+
+
+@pytest.fixture
+def mock_response_success(monkeypatch):
+    def mock_get(*args, **kwargs):
+        mock_resp = Response()
+        mock_resp.status_code = 200
+        mock_resp._content = b"column1,column2\nvalue1,value2"
+        return mock_resp
+
+    monkeypatch.setattr("requests.get", mock_get)
+
+
+@pytest.fixture
+def mock_response_failure(monkeypatch):
+    def mock_get(*args, **kwargs):
+        mock_resp = Response()
+        mock_resp.status_code = 400
+        mock_resp._content = b"something went wrong"
+        return mock_resp
+
+    monkeypatch.setattr("requests.get", mock_get)
+
+
+@pytest.fixture
+def mock_upload_data_success(monkeypatch, mocker):
+    async def mock_upload(*args, **kwargs):
+        print({"message": "upload successful"})
+
+    monkeypatch.setattr(
+        "scripts.invocation_upload.upload_data", mocker.AsyncMock(side_effect=mock_upload)
+    )
+
+
+def test_read_data_from_url(mock_response_success):
+    url = "https://example.com/data.csv"
+    df = read_data(url)
+
+    # Verify the DataFrame contents
+    assert not df.empty
+    assert (df.columns == ["column1", "column2"]).all()
+    assert df.iloc[0]["column1"] == "value1"
+
+
+def test_read_data_from_url_failure(mock_response_failure):
+    url = "https://example.com/data.csv"
+
+    # Verify the expected exception is raised
+    with pytest.raises(Exception, match="Failed to download CSV. Status code: 400"):
+        read_data(url)
+
+
+def test_read_data_from_feather(tmp_path):
+    # Create a temporary Feather file with sample data
+    df_sample = pd.DataFrame({"column1": ["value1"], "column2": ["value2"]})
+    feather_file = tmp_path / "all_indicators.feather"
+    df_sample.to_feather(feather_file)
+
+    df = read_data(str(feather_file))
+
+    # Verify the DataFrame contents
+    assert not df.empty
+    assert (df.columns == ["column1", "column2"]).all()
+    assert df.iloc[0]["column1"] == "value1"
+
+
+def test_read_data_from_feather_failure():
+    feather_file = "data.feather"
+
+    # Verify the expected exception is raised
+    with pytest.raises(FileNotFoundError, match=re.escape(
+            "[Errno 2] No such file or directory: 'data.feather'")):
+        read_data(str(feather_file))
+
+
+def test_transform_data():
+    dataframe = pd.DataFrame(invocations_row, columns=columns)
+    data = transform_data(dataframe)
+
+    assert len(data) == 1
+    assert isinstance(data, List), "Data is not a list"
+    for item in data:
+        assert isinstance(item, Invocation), "Item in list is not an instance of Invocation"
+
+
+def test_transform_data_wrong_dataframe():
+    df_sample = pd.DataFrame({"column1": ["value1"], "column2": ["value2"]})
+
+    with pytest.raises(KeyError, match="Error key not found in row 0: 'doi'"):
+        transform_data(df_sample)
+
+
+def test_transform_data_validation_error():
+    df_sample = pd.DataFrame({"pmid": [0], "doi": [0], "filename": [0]})
+
+    with pytest.raises(ValidationError, match=re.escape(
+            '1 validation error for Work\nfilename\n Input should be a valid string [type=string_type, input_value=np.int64(0), input_type=int64]\n For further information visit https://errors.pydantic.dev/2.8/v/string_type')):
+        transform_data(df_sample)
+
+
+@pytest.mark.asyncio
+async def test_upload_data_success(mock_database):
+    # Define test data
+    invocation_list = []
+    mongo_uri = "mongodb://test_uri"
+    db_name = "test_db"
+
+    # If an exception is raised by the call below, the test fails;
+    # there is no need for a 'pytest.raises' block.
+    await upload_data(invocation_list, mongo_uri, db_name)
+
+
+@pytest.mark.asyncio
+async def test_upload_data_failure(mock_database, caplog):
+    invocation_list = []
+
+    with pytest.raises(TypeError, match="NoneType' object is not iterable"):
+        await upload_data(invocation_list, None, None)
+
+
+def test_main(mock_read_data, mock_upload_data_success, capfd):
+    main(data_path="all_indicators.feather")
+
+    out, err = capfd.readouterr()
+
+    # Assert the expected output was printed
+    assert "{'message': 'upload successful'}" in out
+    assert not err
+
+
+def test_main_failure(mock_read_data):
+    with pytest.raises(TypeError, match="NoneType' object is not iterable"):
+        main(data_path=None)
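
A minimal end-to-end usage sketch for the new upload script (illustrative only, not part of the diff above; it assumes the repository root is on PYTHONPATH, that all_indicators.feather is in the working directory, and that a MongoDB instance is reachable at the placeholder URL and database name below):

    import asyncio
    import os

    from scripts.invocation_upload import read_data, transform_data, upload_data

    # Placeholder connection settings; point these at a real instance before running.
    os.environ.setdefault("DATABASE_URL", "mongodb://localhost:27017")
    os.environ.setdefault("DATABASE_NAME", "osm")

    df = read_data("all_indicators.feather")  # or an https:// URL to a CSV
    invocations = transform_data(df)          # pandas rows -> Invocation models
    asyncio.run(
        upload_data(invocations, os.environ["DATABASE_URL"], os.environ["DATABASE_NAME"])
    )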