Skip to content

Commit

Permalink
Fixed pre-commit for all files | Fixed return types and types for som…
Browse files Browse the repository at this point in the history
…e functions with test cases
  • Loading branch information
spirosmaggioros committed Nov 18, 2024
1 parent 4800353 commit f020590
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 234 deletions.
97 changes: 0 additions & 97 deletions merge_ROI_demo_and_test.py

This file was deleted.

6 changes: 3 additions & 3 deletions spare_scores/data_prep.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
import random
from typing import Any, Tuple, Union
from typing import Any, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -16,7 +16,7 @@ def check_train(
to_predict: str,
verbose: int = 1, # this needs to be removed(non used)
pos_group: str = "",
) -> Union[str, Tuple[pd.DataFrame, list, str]]:
) -> Union[Tuple[pd.DataFrame, list, str], str]:
"""
Checks training dataframe for errors.
Expand Down Expand Up @@ -221,7 +221,7 @@ def smart_unique(
def age_sex_match(
df1: pd.DataFrame,
df2: Union[pd.DataFrame, None] = None,
to_match: str = "",
to_match: Optional[str] = "",
p_threshold: float = 0.15,
verbose: int = 1,
age_out_percentage: float = 20,
Expand Down
4 changes: 2 additions & 2 deletions spare_scores/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
import pickle
from typing import Any, Union
from typing import Any, Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -43,7 +43,7 @@ def add_file_extension(filename: str, extension: str) -> str:
return filename


def check_file_exists(filename: str, logger: Any) -> Any:
def check_file_exists(filename: Optional[str], logger: Any) -> Any:
"""
Checks if file exists
Expand Down
25 changes: 0 additions & 25 deletions tests/conftest.py

This file was deleted.

36 changes: 21 additions & 15 deletions tests/unit/test_data_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@

class CheckDataPrep(unittest.TestCase):

def test_check_train(self):
def test_check_train(self) -> None:
# Test case 1: Valid input dataframe and predictors
self.df_fixture = load_df("../fixtures/sample_data.csv")
predictors = ["ROI1", "ROI2", "ROI3"]
to_predict = "Sex"
pos_group = "M"
filtered_df, filtered_predictors, mdl_type = check_train(
result = check_train(
self.df_fixture, predictors, to_predict, pos_group=pos_group
)
if isinstance(result, str):
self.fail("check_train returned an error")
else:
filtered_df, filtered_predictors, mdl_type = result
self.assertTrue(
filtered_df.equals(self.df_fixture)
) # Check if filtered dataframe is the same as the input dataframe
Expand All @@ -42,7 +46,9 @@ def test_check_train(self):
predictors = ["Var1", "Var2"]
to_predict = "ToPredict"
pos_group = "1"
res = check_train(df_missing_columns, predictors, to_predict, pos_group)
res = check_train(
df_missing_columns, predictors, to_predict, pos_group=pos_group
)
self.assertTrue(res == "Variable to predict is not in the input dataframe.")

# Test case 3: Predictor not in input dataframe
Expand All @@ -57,10 +63,10 @@ def test_check_train(self):
predictors = ["Var1", "Var2"] # Var2 is not in the input dataframe
to_predict = "ToPredict"
pos_group = "1"
res = check_train(df, predictors, to_predict, pos_group)
res = check_train(df, predictors, to_predict, pos_group=pos_group)
self.assertTrue(res == "Not all predictors exist in the input dataframe.")

def test_check_test(self):
def test_check_test(self) -> None:
# Test case 1: Valid input dataframe and meta_data
df = pd.DataFrame(
{
Expand Down Expand Up @@ -121,9 +127,9 @@ def test_check_test(self):
),
}
res = check_test(df_age_outside_range, meta_data)
self.assertTrue(res[1] == None)
self.assertTrue(res[1] is None)

def test_smart_unique(self):
def test_smart_unique(self) -> None:
# test case 1: testing smart_unique with df2=None, to_predict=None
self.df_fixture = load_df("../fixtures/sample_data.csv")
result = smart_unique(self.df_fixture, None)
Expand All @@ -146,8 +152,8 @@ def test_smart_unique(self):

# test case 3: testing smart_unique with variance and no duplicate ID's. df2=None
self.df_fixture = load_df("../fixtures/sample_data.csv")
result = smart_unique(self.df_fixture, None, "ROI1")
self.assertTrue(result.equals(self.df_fixture))
result_df: pd.DataFrame = smart_unique(self.df_fixture, None, "ROI1")
self.assertTrue(result_df.equals(self.df_fixture))

# test case 4: testing smart_unique with variance and duplicate ID's. df2=None
self.df_fixture = pd.DataFrame(data=df)
Expand All @@ -161,7 +167,7 @@ def test_smart_unique(self):
"ROI2": 0.73,
}
self.df_fixture = self.df_fixture._append(new_row, ignore_index=True)
result = smart_unique(self.df_fixture, None, "ROI1")
result_df_2: pd.DataFrame = smart_unique(self.df_fixture, None, "ROI1")
correct_df = {
"Id": [1.0, 2.0, 3.0, 4.0, 5.0, float("nan")],
"ScanID": [
Expand All @@ -186,7 +192,7 @@ def test_smart_unique(self):
],
}
correct_df = pd.DataFrame(data=correct_df)
self.assertTrue(result.equals(correct_df))
self.assertTrue(result_df_2.equals(correct_df))

# test case 5: testing df2 != None and no_df2=False
df1 = {
Expand All @@ -199,10 +205,10 @@ def test_smart_unique(self):
self.df_fixture1 = pd.DataFrame(data=df1)
self.df_fixture2 = pd.DataFrame(data=df2)

result = smart_unique(self.df_fixture1, self.df_fixture2, to_predict=None)
result = smart_unique(self.df_fixture1, self.df_fixture2, to_predict="")
self.assertTrue(result == (self.df_fixture1, self.df_fixture2))

def test_age_sex_match(self):
def test_age_sex_match(self) -> None:
# test case 1: testing df2=None and to_match=None
self.df_fixture = load_df("../fixtures/sample_data.csv")
result = age_sex_match(self.df_fixture, None)
Expand Down Expand Up @@ -265,7 +271,7 @@ def test_age_sex_match(self):
print(result)
self.assertTrue(result.equals(correct_df))

def test_logging_basic_config(self):
def test_logging_basic_config(self) -> None:
logging_level = {
0: logging.WARNING,
1: logging.INFO,
Expand All @@ -291,5 +297,5 @@ def test_logging_basic_config(self):
self.assertTrue(os.path.exists("test_data_prep.py"))
self.assertTrue(result == logging.getLogger())

def test_convert_cat_variables(self):
def test_convert_cat_variables(self) -> None:
pass
Loading

0 comments on commit f020590

Please sign in to comment.