diff --git a/folktables/acs.py b/folktables/acs.py index f6d6d03..dedd209 100644 --- a/folktables/acs.py +++ b/folktables/acs.py @@ -65,6 +65,9 @@ def get_definitions(self, download=False): return load_definitions(root_dir=self._root_dir, year=self._survey_year, horizon=self._horizon, download=download) +def fillna_safe(x, value=-1): + x = np.nan_to_num(x, value) + return pd.DataFrame(x).fillna(value=value).values def adult_filter(data): """Mimic the filters in place for Adult data. @@ -98,7 +101,7 @@ def adult_filter(data): target_transform=lambda x: x > 50000, group='RAC1P', preprocess=adult_filter, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) ACSEmployment = folktables.BasicProblem( @@ -124,7 +127,7 @@ def adult_filter(data): target_transform=lambda x: x == 1, group='RAC1P', preprocess=lambda x: x, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) ACSHealthInsurance = folktables.BasicProblem( @@ -159,7 +162,7 @@ def adult_filter(data): target_transform=lambda x: x == 1, group='RAC1P', preprocess=lambda x: x, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) def public_coverage_filter(data): @@ -197,7 +200,7 @@ def public_coverage_filter(data): target_transform=lambda x: x == 1, group='RAC1P', preprocess=public_coverage_filter, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) def travel_time_filter(data): @@ -233,7 +236,7 @@ def travel_time_filter(data): target_transform=lambda x: x > 20, group='RAC1P', preprocess=travel_time_filter, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) @@ -274,7 +277,7 @@ def mobility_filter(data): target_transform=lambda x: x == 1, group='RAC1P', preprocess=mobility_filter, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) def employment_filter(data): @@ -311,7 +314,7 @@ def employment_filter(data): target_transform=lambda x: x == 1, group='RAC1P', preprocess=employment_filter, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) ACSIncomePovertyRatio = folktables.BasicProblem( @@ -341,5 +344,5 @@ def employment_filter(data): target_transform=lambda x: x < 250, group='RAC1P', preprocess=lambda x: x, - postprocess=lambda x: np.nan_to_num(x, -1), + postprocess=fillna_safe, ) diff --git a/folktables/load_acs.py b/folktables/load_acs.py index 9464a71..f954a84 100644 --- a/folktables/load_acs.py +++ b/folktables/load_acs.py @@ -144,6 +144,7 @@ def load_definitions(root_dir, year=2018, horizon='1-Year', download=False): year_string = year if horizon == '1-Year' else f'{year - 4}-{year}' url = f'https://www2.census.gov/programs-surveys/acs/tech_docs/pums/data_dict/PUMS_Data_Dictionary_{year_string}.csv' + os.makedirs(base_datadir, exist_ok=True) response = requests.get(url) with open(file_path, 'wb') as handle: handle.write(response.content)