Commit

Merge remote-tracking branch 'origin-folktables/main'
# Conflicts:
#	folktables/load_acs.py
baraldian committed Mar 25, 2024
2 parents 7474ac6 + 731b8d1 commit 1762722
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions folktables/load_acs.py
@@ -8,12 +8,14 @@
import numpy as np
import pandas as pd


state_list = ['AL', 'AK', 'AZ', 'AR', 'CA', 'CO', 'CT', 'DE', 'FL', 'GA', 'HI',
              'ID', 'IL', 'IN', 'IA', 'KS', 'KY', 'LA', 'ME', 'MD', 'MA', 'MI',
              'MN', 'MS', 'MO', 'MT', 'NE', 'NV', 'NH', 'NJ', 'NM', 'NY', 'NC',
              'ND', 'OH', 'OK', 'OR', 'PA', 'RI', 'SC', 'SD', 'TN', 'TX', 'UT',
              'VT', 'VA', 'WA', 'WV', 'WI', 'WY', 'PR']


_STATE_CODES = {'AL': '01', 'AK': '02', 'AZ': '04', 'AR': '05', 'CA': '06',
                'CO': '08', 'CT': '09', 'DE': '10', 'FL': '12', 'GA': '13',
                'HI': '15', 'ID': '16', 'IL': '17', 'IN': '18', 'IA': '19',
@@ -33,10 +35,10 @@ def download_and_extract(url, datadir, remote_fname, file_name, delete_download=
    response = requests.get(url)
    with open(download_path, 'wb') as handle:
        handle.write(response.content)

    with zipfile.ZipFile(download_path, 'r') as zip_ref:
        zip_ref.extract(file_name, path=datadir)

    if delete_download and download_path != os.path.join(datadir, file_name):
        os.remove(download_path)

@@ -55,26 +57,24 @@ def initialize_and_download(datadir, state, year, horizon, survey, download=False
    else:
        # 2016 and earlier use different file names
        file_name = f'ss{str(year)[-2:]}{survey_code}{state.lower()}.csv'

    # Assume that if the path exists and is a file, then it has been
    # downloaded correctly
    file_path = os.path.join(datadir, file_name)
    if os.path.isfile(file_path):
        return file_path
    if not download:
-        raise FileNotFoundError(
-            f'Could not find {year} {horizon} {survey} survey data for {state} in {datadir}. Call get_data with download=True to download the dataset.')
+        raise FileNotFoundError(f'Could not find {year} {horizon} {survey} survey data for {state} in {datadir}. Call get_data with download=True to download the dataset.')

    print(f'Downloading data for {year} {horizon} {survey} survey for {state}...')
    # Download and extract file
-    base_url = f'https://www2.census.gov/programs-surveys/acs/data/pums/{year}/{horizon}'
+    base_url= f'https://www2.census.gov/programs-surveys/acs/data/pums/{year}/{horizon}'
    remote_fname = f'csv_{survey_code}{state.lower()}.zip'
    url = f'{base_url}/{remote_fname}'
    try:
        download_and_extract(url, datadir, remote_fname, file_name, delete_download=True)
    except Exception as e:
-        print(
-            f'\n{os.path.join(datadir, remote_fname)} may be corrupted. Please try deleting it and rerunning this command.\n')
+        print(f'\n{os.path.join(datadir, remote_fname)} may be corrupted. Please try deleting it and rerunning this command.\n')
        print(f'Exception: ', e)

    return file_path
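As a concrete trace of the naming logic above (hypothetical inputs; survey_code='p' for the person survey is assumed from an unchanged part of the file): year=2015, horizon='1-Year', state='CA' takes the pre-2017 branch, giving file_name='ss15pca.csv', remote_fname='csv_pca.zip', and url='https://www2.census.gov/programs-surveys/acs/data/pums/2015/1-Year/csv_pca.zip'.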
@@ -99,12 +99,12 @@ def load_acs(root_dir, states=None, year=2018, horizon='1-Year',

    if states is None:
        states = state_list

    random.seed(random_seed)

    base_datadir = os.path.join(root_dir, str(year), horizon)
    os.makedirs(base_datadir, exist_ok=True)

    file_names = []
    for state in states:
        file_names.append(
@@ -114,7 +114,7 @@
    dtypes = {'PINCP': np.float64, 'RT': str, 'SOCP': str, 'SERIALNO': str, 'NAICSP': str}
    df_list = []
    for file_name in file_names:
-        df = pd.read_csv(file_name, dtype=dtypes, engine="c").replace(' ','')
+        df = pd.read_csv(file_name, dtype=dtypes).replace(' ','')
        if serial_filter_list is not None:
            df = df[df['SERIALNO'].isin(serial_filter_list)]
        df_list.append(df)
@@ -191,4 +191,4 @@ def generate_categories(features, definition_df):
            del mapping_dict[-99999999999999.0]

        categories[feature] = mapping_dict
-    return categories
+    return categories
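
For orientation, a minimal sketch of how the changed loader is typically reached. load_acs is called indirectly through folktables' public API; the ACSDataSource and ACSEmployment names below are assumed from the upstream folktables README rather than from this diff:

# Sketch only: public-API names assumed from the upstream folktables README.
from folktables import ACSDataSource, ACSEmployment

# Files are cached under <root_dir>/<year>/<horizon>/, matching base_datadir in load_acs.
data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')

# download=True routes through initialize_and_download(); without it, a
# missing file raises the FileNotFoundError shown in the diff above.
acs_data = data_source.get_data(states=['CA'], download=True)

features, labels, group = ACSEmployment.df_to_numpy(acs_data)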
