diff --git a/tests/test_utils.py b/tests/test_utils.py index d8f605e..0255a16 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -283,3 +283,14 @@ def test_download(mocker, caplog): assert "Cannot download dataset" in str(excinfo.value) finally: os.remove(resource_path) + + +def test_get_covid_19(): + X, graph, states = tsgm.utils.get_covid_19() + assert len(states) == 51 and "new york" in states and "california" in states + assert len(graph[0]) == len(states) # nodes + assert len(graph[1]) == 220 # edges + assert X.shape[0] == len(states) + assert len(X.shape) == 3 + assert X.shape[2] == 4 + assert X.shape[1] >= 150 \ No newline at end of file diff --git a/tsgm/utils/__init__.py b/tsgm/utils/__init__.py index ac7a141..c87f8b1 100644 --- a/tsgm/utils/__init__.py +++ b/tsgm/utils/__init__.py @@ -1,6 +1,7 @@ from tsgm.utils.file_utils import * # noqa from tsgm.utils.data_processing import * # noqa from tsgm.utils.visualization import * # noqa +from tsgm.utils.covid19_data_utils import * # noqa from tsgm.utils.datasets import * # noqa from tsgm.utils.utils import * # noqa from tsgm.utils.mmd import * # noqa diff --git a/tsgm/utils/covid19_data_utils.py b/tsgm/utils/covid19_data_utils.py new file mode 100644 index 0000000..31c4e9c --- /dev/null +++ b/tsgm/utils/covid19_data_utils.py @@ -0,0 +1,186 @@ +""" +Utils for COVID-19 graph time series dataset: +The dataset is based on data from The New York Times, based on reports from state and local health agencies [1]. + +And was adapted to graph case in [2]. +[1] The New York Times. (2021). Coronavirus (Covid-19) Data in the United States. Retrieved [Insert Date Here], from https://github.com/nytimes/covid-19-data. +[2] + +The code is an adapted version from: +https://github.com/AlexanderVNikitin/covid19-on-graphs +""" + +import pandas as pd + + +STATE_ADJACENCIES = { + "washington": ["oregon", "idaho"], + "oregon": ["washington", "idaho", "nevada", "california"], + "california": ["oregon", "nevada", "arizona"], + "idaho": ["washington", "montana", "wyoming", "utah", "nevada", "oregon"], + "montana": ["north dakota", "south dakota", "wyoming", "idaho"], + "north dakota": ["minnesota", "south dakota", "montana"], + "south dakota": ["north dakota", "minnesota", "iowa", "nebraska", "wyoming", "montana"], + "minnesota": ["wisconsin", "iowa", "south dakota", "north dakota"], + "michigan": ["indiana", "ohio", "wisconsin"], + "ohio": ["michigan", "pennsylvania", "west virginia", "kentucky", "indiana"], + "pennsylvania": ["new york", "new jersey", "delaware", "maryland", "west virginia", "ohio"], + "new york": ["vermont", "massachusetts", "rhode island", "new jersey", "pennsylvania", "connecticut"], + "vermont": ["new hampshire", "massachusetts", "new york"], + "new hampshire": ["maine", "massachusetts", "vermont"], + "maine": ["new hampshire"], + "wyoming": ["montana", "south dakota", "nebraska", "colorado", "utah", "idaho"], + "nebraska": ["south dakota", "iowa", "missouri", "kansas", "colorado", "wyoming"], + "iowa": ["minnesota", "wisconsin", "illinois", "missouri", "nebraska", "south dakota"], + "wisconsin": ["minnesota", "iowa", "illinois", "michigan"], + "illinois": ["wisconsin", "indiana", "kentucky", "missouri", "iowa"], + "indiana": ["michigan", "ohio", "kentucky", "illinois"], + "west virginia": ["ohio", "pennsylvania", "maryland", "virginia", "kentucky"], + "maryland": ["delaware", "pennsylvania", "west virginia", "virginia", "district of columbia"], + "delaware": ["maryland", "pennsylvania", "new jersey"], + "new jersey": ["delaware", "pennsylvania", "new york"], + "connecticut": ["new york", "massachusetts", "rhode island"], + "rhode island": ["connecticut", "massachusetts", "new york"], + "district of columbia": ["maryland", "virginia"], + "virginia": ["west virginia", "kentucky", "district of columbia", "maryland", "north carolina", "tennessee"], + "kentucky": ["indiana", "ohio", "west virginia", "virginia", "tennessee", "missouri", "illinois"], + "missouri": ["iowa", "illinois", "kentucky", "tennessee", "arkansas", "oklahoma", "kansas", "nebraska"], + "kansas": ["nebraska", "missouri", "oklahoma", "colorado"], + "colorado": ["wyoming", "nebraska", "kansas", "oklahoma", "new mexico", "utah", "arizona"], + "utah": ["idaho", "wyoming", "colorado", "new mexico", "arizona", "nevada"], + "nevada": ["oregon", "idaho", "utah", "arizona", "california"], + "arizona": ["california", "nevada", "utah", "colorado", "new mexico"], + "new mexico": ["arizona", "utah", "colorado", "oklahoma", "texas"], + "oklahoma": ["colorado", "kansas", "missouri", "arkansas", "texas", "new mexico"], + "texas": ["new mexico", "oklahoma", "arkansas", "louisiana"], + "arkansas": ["oklahoma", "missouri", "tennessee", "mississippi", "louisiana", "texas"], + "louisiana": ["texas", "arkansas", "mississippi"], + "mississippi": ["louisiana", "arkansas", "tennessee", "alabama"], + "tennessee": ["missouri", "kentucky", "virginia", "north carolina", "georgia", "alabama", "mississippi", "arkansas"], + "alabama": ["mississippi", "tennessee", "georgia", "florida"], + "georgia": ["tennessee", "north carolina", "south carolina", "florida", "alabama"], + "florida": ["alabama", "georgia"], + "south carolina": ["georgia", "north carolina"], + "north carolina": ["south carolina", "tennessee", "virginia", "georgia"], + "alaska": [], + "hawaii": [], + "massachusetts": ["new york", "vermont", "new hampshire", "rhode island", "connecticut"], +} + +LIST_OF_STATES = sorted(STATE_ADJACENCIES.keys()) + +# July 1 2019 +STATE_POPULATION = { + "california": 39_512_223, + "texas": 28_995_881, + "florida": 21_477_737, + "new york": 19_453_561, + "pennsylvania": 12_801_989, + "illinois": 12_671_821, + "ohio": 11_689_100, + "georgia": 10_617_423, + "north carolina": 10_488_084, + "michigan": 9_986_857, + "new jersey": 8_882_190, + "virginia": 8_535_519, + "washington": 7_614_893, + "arizona": 7_278_717, + "massachusetts": 6_949_503, + "tennessee": 6_833_174, + "indiana": 6_732_219, + "missouri": 6_137_428, + "maryland": 6_045_680, + "wisconsin": 5_822_434, + "colorado": 5_758_736, + "minnesota": 5_639_632, + "south carolina": 5_148_714, + "alabama": 4_903_185, + "louisiana": 4_648_794, + "kentucky": 4_467_673, + "oregon": 4_217_737, + "oklahoma": 3_956_971, + "connecticut": 3_565_287, + "utah": 3_205_958, + "iowa": 3_155_070, + "nevada": 3_080_156, + "arkansas": 3_017_825, + "mississippi": 2_976_149, + "kansas": 2_913_314, + "new mexico": 2_096_829, + "nebraska": 1_934_408, + "west virginia": 1_792_147, + "idaho": 1_787_065, + "hawaii": 1_415_872, + "new hampshire": 1_359_711, + "maine": 1_344_212, + "montana": 1_068_778, + "rhode island": 1_059_361, + "delaware": 973_764, + "south dakota": 884_659, + "north dakota": 762_062, + "alaska": 731_545, + "district of columbia": 705_749, + "vermont": 623_989, + "wyoming": 578_759, + "virgin islands": 104_914, + "puerto rico": 3_193_694, + "guam": 165_718, +} + + +def aggregate_by_weeks_max(df): + df['date'] = pd.to_datetime(df['date']) # + pd.to_timedelta(7, unit='d') + df = df.groupby(['state', pd.Grouper(key='date', freq='W-MON')])\ + .agg({"cases": max, "deaths": max})\ + .reset_index()\ + .sort_values('date') + return df + + +def get_adjacencies_graph(): + nodes, edges = [], [] + LIST_OF_STATES = sorted(STATE_ADJACENCIES.keys()) + + for state_name in LIST_OF_STATES: + nodes.append(state_name) + + for state, adj_states in STATE_ADJACENCIES.items(): + for adj_state in adj_states: + edges.append((state, adj_state)) + return nodes, edges + + +def covid_dataset(path): + covid_cases_df = pd.read_csv(path) + covid_cases_df["state"] = covid_cases_df["state"].str.lower() + covid_cases_df = aggregate_by_weeks_max(covid_cases_df) + graph = get_adjacencies_graph() + result = {} + for row in covid_cases_df.to_dict(orient="records"): + date = row["date"] + cases = row["cases"] + deaths = row["deaths"] + state = row["state"] + if date not in result: + result[date] = {} + if state in STATE_POPULATION: + result[date][state] = { + "deaths_normalized": deaths / STATE_POPULATION[state], + "cases_normalized": cases / STATE_POPULATION[state], + "deaths": deaths, + "cases": cases, + } + else: + print("[WARNING]: There is no data about population for: ", state) + + # fill missing values with zeros + for date in result.keys(): + for state in LIST_OF_STATES: + if state not in result[date]: + result[date][state] = { + "deaths": 0, + "cases": 0, + "deaths_normalized": 0, + "cases_normalized": 0, + } + return result, graph diff --git a/tsgm/utils/datasets.py b/tsgm/utils/datasets.py index 893b402..c507413 100644 --- a/tsgm/utils/datasets.py +++ b/tsgm/utils/datasets.py @@ -15,6 +15,7 @@ from tensorflow import keras +from tsgm.utils import covid19_data_utils from tsgm.utils import file_utils @@ -22,7 +23,7 @@ logger.setLevel(logging.DEBUG) -def gen_sine_dataset(N, T, D, max_value=10): +def gen_sine_dataset(N: int, T: int, D: int, max_value: int = 10) -> np.ndarray: result = [] for i in range(N): result.append([]) @@ -35,7 +36,7 @@ def gen_sine_dataset(N, T, D, max_value=10): return np.transpose(np.array(result), [0, 2, 1]) -def gen_sine_const_switch_dataset(N, T, D, max_value=10, const=0, frequency_switch=0.1): +def gen_sine_const_switch_dataset(N: int, T: int, D: int, max_value: int = 10, const: int = 0, frequency_switch: float = 0.1) -> tuple: result_X, result_y = [], [] cur_y = 0 scales = np.random.random(D) * max_value @@ -185,7 +186,7 @@ def get_mauna_loa() -> tuple: return X, y -def split_dataset_into_objects(X, y, step=10): +def split_dataset_into_objects(X, y, step=10) -> tuple: assert X.shape[0] == y.shape[0] Xs, ys = [], [] @@ -293,7 +294,7 @@ def get_physionet2012() -> tuple: return train_X, train_y, test_X, test_y, val_X, val_y -def download_physionet2012(): +def download_physionet2012() -> None: """ Downloads the Physionet 2012 dataset files from the Physionet website and extracts them in local folder 'physionet2012' @@ -359,3 +360,39 @@ def _get_physionet_y_dataframe(file_path: str) -> pd.DataFrame: y.index.name = 'recordid' y.reset_index(inplace=True) return y + + +def get_covid_19() -> tuple: + """ + Loads Covid-19 dataset with additional graph information + The dataset is based on data from The New York Times, based on reports from state and local health agencies [1]. + + And was adapted to graph case in [2]. + [1] The New York Times. (2021). Coronavirus (Covid-19) Data in the United States. Retrieved [Insert Date Here], from https://github.com/nytimes/covid-19-data. + [2] Alexander V. Nikitin, St John, Arno Solin, Samuel Kaski Proceedings of The 25th International Conference on Artificial Intelligence and Statistics, PMLR 151:10640-10660, 2022. + + Returns: + ------- + tuple + First element is time series data (n_nodes x n_timestamps x n_features). Each timestamp consists of + the number of deaths, cases, deaths normalized by the population, and cases normalized by the population. + The second element is the graph tuple (nodes, edges). + The third element is the order of states. + """ + base_url = "https://raw.githubusercontent.com/nytimes/covid-19-data/master/us-states.csv" + destination_folder = "covid19" + file_utils.download(base_url, destination_folder) + result, graph = covid19_data_utils.covid_dataset( + os.path.join(destination_folder, "us-states.csv") + ) + + processed_dataset = [] + for timestamp in result.keys(): + processed_dataset.append([]) + for state in covid19_data_utils.LIST_OF_STATES: + cur_data = result[timestamp][state] + processed_dataset[-1].append( + [cur_data["deaths"], cur_data["cases"], + cur_data["deaths_normalized"], cur_data["cases_normalized"]] + ) + return np.transpose(np.array(processed_dataset), (1, 0, 2)), graph, covid19_data_utils.LIST_OF_STATES