diff --git a/notebooks/covid_eda_modular.py b/notebooks/covid_eda_modular.py index d162e11..c6e19ab 100644 --- a/notebooks/covid_eda_modular.py +++ b/notebooks/covid_eda_modular.py @@ -20,14 +20,14 @@ # COMMAND ---------- -from covid_analysis.transforms import * +from covid_analysis import transforms import pandas as pd df = pd.read_csv(data_path) -df = filter_country(df, country='DZA') -df = pivot_and_clean(df, fillna=0) -df = clean_spark_cols(df) -df = index_to_col(df, colname='date') +df = transforms.filter_country(df, country='DZA') +df = transforms.pivot_and_clean(df, fillna=0) +df = transforms.clean_spark_cols(df) +df = transforms.index_to_col(df, colname='date') # Convert from Pandas to a pyspark sql DataFrame. df = spark.createDataFrame(df) diff --git a/tests/transforms_test.py b/tests/transforms_test.py index 8fac7a0..33c20dc 100644 --- a/tests/transforms_test.py +++ b/tests/transforms_test.py @@ -4,7 +4,7 @@ import os import pandas as pd import numpy as np -from covid_analysis.transforms import * +from covid_analysis import transforms from pyspark.sql import SparkSession @@ -30,28 +30,28 @@ def colnames_df() -> pd.DataFrame: ], ) return df - + # Make sure the filter works as expected. def test_filter(raw_input_df): - filtered = filter_country(raw_input_df) + filtered = transforms.filter_country(raw_input_df) assert filtered.iso_code.drop_duplicates()[0] == "USA" # The test data has NaNs for Daily ICU occupancy; this should get filled to 0. def test_pivot(raw_input_df): - pivoted = pivot_and_clean(raw_input_df, 0) + pivoted = transforms.pivot_and_clean(raw_input_df, 0) assert pivoted["Daily ICU occupancy"][0] == 0 # Test column cleaning. def test_clean_cols(colnames_df): - cleaned = clean_spark_cols(colnames_df) + cleaned = transforms.clean_spark_cols(colnames_df) cols_w_spaces = cleaned.filter(regex=(" ")) assert cols_w_spaces.empty == True # Test column creation from index. 
def test_index_to_col(raw_input_df):
    """Check that ``transforms.index_to_col`` exposes the index as a column.

    The helper is called the same way the notebook calls it: the frame is
    passed in, the target column is named through ``colname``, and the
    returned frame is the one inspected. The previous revision of this
    change assigned the helper's return value into a single column and did
    not pass ``colname``, so the test did not exercise the helper as it is
    actually used.
    """
    # NOTE(review): assumes index_to_col returns the DataFrame (the notebook
    # rebinds ``df`` to its result) - confirm against transforms.py.
    result = transforms.index_to_col(raw_input_df, colname="col_from_index")

    # Compare on the returned frame so the check is valid whether the helper
    # mutates its argument in place or returns a new frame.
    assert (result.index == result.col_from_index).all()