add covid-19 dataset
AlexanderVNikitin committed Oct 31, 2023
1 parent bad59b2 commit e2884c2
Showing 4 changed files with 239 additions and 4 deletions.
11 changes: 11 additions & 0 deletions tests/test_utils.py
@@ -283,3 +283,14 @@ def test_download(mocker, caplog):
assert "Cannot download dataset" in str(excinfo.value)
finally:
os.remove(resource_path)


def test_get_covid_19():
X, graph, states = tsgm.utils.get_covid_19()
assert len(states) == 51 and "new york" in states and "california" in states
assert len(graph[0]) == len(states) # nodes
assert len(graph[1]) == 220 # edges
assert X.shape[0] == len(states)
assert len(X.shape) == 3
assert X.shape[2] == 4
assert X.shape[1] >= 150
1 change: 1 addition & 0 deletions tsgm/utils/__init__.py
@@ -1,6 +1,7 @@
from tsgm.utils.file_utils import * # noqa
from tsgm.utils.data_processing import * # noqa
from tsgm.utils.visualization import * # noqa
from tsgm.utils.covid19_data_utils import * # noqa
from tsgm.utils.datasets import * # noqa
from tsgm.utils.utils import * # noqa
from tsgm.utils.mmd import * # noqa
186 changes: 186 additions & 0 deletions tsgm/utils/covid19_data_utils.py
@@ -0,0 +1,186 @@
"""
Utils for COVID-19 graph time series dataset:
The dataset is based on data from The New York Times, based on reports from state and local health agencies [1].
And was adapted to graph case in [2].
[1] The New York Times. (2021). Coronavirus (Covid-19) Data in the United States. Retrieved [Insert Date Here], from https://github.com/nytimes/covid-19-data.
[2]
The code is an adapted version from:
https://github.com/AlexanderVNikitin/covid19-on-graphs
"""

import pandas as pd


STATE_ADJACENCIES = {
"washington": ["oregon", "idaho"],
"oregon": ["washington", "idaho", "nevada", "california"],
"california": ["oregon", "nevada", "arizona"],
"idaho": ["washington", "montana", "wyoming", "utah", "nevada", "oregon"],
"montana": ["north dakota", "south dakota", "wyoming", "idaho"],
"north dakota": ["minnesota", "south dakota", "montana"],
"south dakota": ["north dakota", "minnesota", "iowa", "nebraska", "wyoming", "montana"],
"minnesota": ["wisconsin", "iowa", "south dakota", "north dakota"],
"michigan": ["indiana", "ohio", "wisconsin"],
"ohio": ["michigan", "pennsylvania", "west virginia", "kentucky", "indiana"],
"pennsylvania": ["new york", "new jersey", "delaware", "maryland", "west virginia", "ohio"],
"new york": ["vermont", "massachusetts", "rhode island", "new jersey", "pennsylvania", "connecticut"],
"vermont": ["new hampshire", "massachusetts", "new york"],
"new hampshire": ["maine", "massachusetts", "vermont"],
"maine": ["new hampshire"],
"wyoming": ["montana", "south dakota", "nebraska", "colorado", "utah", "idaho"],
"nebraska": ["south dakota", "iowa", "missouri", "kansas", "colorado", "wyoming"],
"iowa": ["minnesota", "wisconsin", "illinois", "missouri", "nebraska", "south dakota"],
"wisconsin": ["minnesota", "iowa", "illinois", "michigan"],
"illinois": ["wisconsin", "indiana", "kentucky", "missouri", "iowa"],
"indiana": ["michigan", "ohio", "kentucky", "illinois"],
"west virginia": ["ohio", "pennsylvania", "maryland", "virginia", "kentucky"],
"maryland": ["delaware", "pennsylvania", "west virginia", "virginia", "district of columbia"],
"delaware": ["maryland", "pennsylvania", "new jersey"],
"new jersey": ["delaware", "pennsylvania", "new york"],
"connecticut": ["new york", "massachusetts", "rhode island"],
"rhode island": ["connecticut", "massachusetts", "new york"],
"district of columbia": ["maryland", "virginia"],
"virginia": ["west virginia", "kentucky", "district of columbia", "maryland", "north carolina", "tennessee"],
"kentucky": ["indiana", "ohio", "west virginia", "virginia", "tennessee", "missouri", "illinois"],
"missouri": ["iowa", "illinois", "kentucky", "tennessee", "arkansas", "oklahoma", "kansas", "nebraska"],
"kansas": ["nebraska", "missouri", "oklahoma", "colorado"],
"colorado": ["wyoming", "nebraska", "kansas", "oklahoma", "new mexico", "utah", "arizona"],
"utah": ["idaho", "wyoming", "colorado", "new mexico", "arizona", "nevada"],
"nevada": ["oregon", "idaho", "utah", "arizona", "california"],
"arizona": ["california", "nevada", "utah", "colorado", "new mexico"],
"new mexico": ["arizona", "utah", "colorado", "oklahoma", "texas"],
"oklahoma": ["colorado", "kansas", "missouri", "arkansas", "texas", "new mexico"],
"texas": ["new mexico", "oklahoma", "arkansas", "louisiana"],
"arkansas": ["oklahoma", "missouri", "tennessee", "mississippi", "louisiana", "texas"],
"louisiana": ["texas", "arkansas", "mississippi"],
"mississippi": ["louisiana", "arkansas", "tennessee", "alabama"],
"tennessee": ["missouri", "kentucky", "virginia", "north carolina", "georgia", "alabama", "mississippi", "arkansas"],
"alabama": ["mississippi", "tennessee", "georgia", "florida"],
"georgia": ["tennessee", "north carolina", "south carolina", "florida", "alabama"],
"florida": ["alabama", "georgia"],
"south carolina": ["georgia", "north carolina"],
"north carolina": ["south carolina", "tennessee", "virginia", "georgia"],
"alaska": [],
"hawaii": [],
"massachusetts": ["new york", "vermont", "new hampshire", "rhode island", "connecticut"],
}

LIST_OF_STATES = sorted(STATE_ADJACENCIES.keys())

# State population estimates as of July 1, 2019
STATE_POPULATION = {
"california": 39_512_223,
"texas": 28_995_881,
"florida": 21_477_737,
"new york": 19_453_561,
"pennsylvania": 12_801_989,
"illinois": 12_671_821,
"ohio": 11_689_100,
"georgia": 10_617_423,
"north carolina": 10_488_084,
"michigan": 9_986_857,
"new jersey": 8_882_190,
"virginia": 8_535_519,
"washington": 7_614_893,
"arizona": 7_278_717,
"massachusetts": 6_949_503,
"tennessee": 6_833_174,
"indiana": 6_732_219,
"missouri": 6_137_428,
"maryland": 6_045_680,
"wisconsin": 5_822_434,
"colorado": 5_758_736,
"minnesota": 5_639_632,
"south carolina": 5_148_714,
"alabama": 4_903_185,
"louisiana": 4_648_794,
"kentucky": 4_467_673,
"oregon": 4_217_737,
"oklahoma": 3_956_971,
"connecticut": 3_565_287,
"utah": 3_205_958,
"iowa": 3_155_070,
"nevada": 3_080_156,
"arkansas": 3_017_825,
"mississippi": 2_976_149,
"kansas": 2_913_314,
"new mexico": 2_096_829,
"nebraska": 1_934_408,
"west virginia": 1_792_147,
"idaho": 1_787_065,
"hawaii": 1_415_872,
"new hampshire": 1_359_711,
"maine": 1_344_212,
"montana": 1_068_778,
"rhode island": 1_059_361,
"delaware": 973_764,
"south dakota": 884_659,
"north dakota": 762_062,
"alaska": 731_545,
"district of columbia": 705_749,
"vermont": 623_989,
"wyoming": 578_759,
"virgin islands": 104_914,
"puerto rico": 3_193_694,
"guam": 165_718,
}


def aggregate_by_weeks_max(df: pd.DataFrame) -> pd.DataFrame:
    """Aggregates the daily cumulative counts into weekly values, keeping the per-week maximum for each state."""
    df['date'] = pd.to_datetime(df['date'])
    df = df.groupby(['state', pd.Grouper(key='date', freq='W-MON')])\
        .agg({"cases": "max", "deaths": "max"})\
        .reset_index()\
        .sort_values('date')
    return df


def get_adjacencies_graph() -> tuple:
    """Returns the US state adjacency graph as a (nodes, edges) tuple; each adjacency yields one directed edge per direction."""
    nodes = list(LIST_OF_STATES)
    edges = []
    for state, adj_states in STATE_ADJACENCIES.items():
        for adj_state in adj_states:
            edges.append((state, adj_state))
    return nodes, edges


def covid_dataset(path: str) -> tuple:
    """
    Builds the weekly COVID-19 graph dataset from the NYT us-states.csv file at `path`.
    Returns (result, graph): `result` maps each weekly date to per-state raw and
    population-normalized cases and deaths, and `graph` is the adjacency graph (nodes, edges).
    """
covid_cases_df = pd.read_csv(path)
covid_cases_df["state"] = covid_cases_df["state"].str.lower()
covid_cases_df = aggregate_by_weeks_max(covid_cases_df)
graph = get_adjacencies_graph()
result = {}
for row in covid_cases_df.to_dict(orient="records"):
date = row["date"]
cases = row["cases"]
deaths = row["deaths"]
state = row["state"]
if date not in result:
result[date] = {}
if state in STATE_POPULATION:
result[date][state] = {
"deaths_normalized": deaths / STATE_POPULATION[state],
"cases_normalized": cases / STATE_POPULATION[state],
"deaths": deaths,
"cases": cases,
}
else:
print("[WARNING]: There is no data about population for: ", state)

# fill missing values with zeros
for date in result.keys():
for state in LIST_OF_STATES:
if state not in result[date]:
result[date][state] = {
"deaths": 0,
"cases": 0,
"deaths_normalized": 0,
"cases_normalized": 0,
}
return result, graph
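
A minimal usage sketch of the new module (the local CSV path below is an assumption; within the library the file is downloaded automatically by tsgm.utils.get_covid_19, shown further down):

from tsgm.utils import covid19_data_utils

# Assumed local copy of the NYT us-states.csv (see [1] in the module docstring).
result, (nodes, edges) = covid19_data_utils.covid_dataset("covid19/us-states.csv")

print(len(nodes))                     # 51 nodes: 50 states plus the District of Columbia
print(len(edges))                     # directed adjacency edges, one per direction
some_week = next(iter(result))        # weekly timestamps are the dictionary keys
print(result[some_week]["new york"])  # dict with deaths, cases, deaths_normalized, cases_normalized
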
45 changes: 41 additions & 4 deletions tsgm/utils/datasets.py
@@ -15,14 +15,15 @@

from tensorflow import keras

from tsgm.utils import covid19_data_utils
from tsgm.utils import file_utils


logger = logging.getLogger('utils')
logger.setLevel(logging.DEBUG)


def gen_sine_dataset(N, T, D, max_value=10):
def gen_sine_dataset(N: int, T: int, D: int, max_value: int = 10) -> np.ndarray:
result = []
for i in range(N):
result.append([])
@@ -35,7 +36,7 @@ def gen_sine_dataset(N, T, D, max_value=10):
return np.transpose(np.array(result), [0, 2, 1])


def gen_sine_const_switch_dataset(N, T, D, max_value=10, const=0, frequency_switch=0.1):
def gen_sine_const_switch_dataset(N: int, T: int, D: int, max_value: int = 10, const: int = 0, frequency_switch: float = 0.1) -> tuple:
result_X, result_y = [], []
cur_y = 0
scales = np.random.random(D) * max_value
@@ -185,7 +186,7 @@ def get_mauna_loa() -> tuple:
return X, y


def split_dataset_into_objects(X, y, step=10):
def split_dataset_into_objects(X, y, step=10) -> tuple:
assert X.shape[0] == y.shape[0]

Xs, ys = [], []
@@ -293,7 +294,7 @@ def get_physionet2012() -> tuple:
return train_X, train_y, test_X, test_y, val_X, val_y


def download_physionet2012():
def download_physionet2012() -> None:
"""
Downloads the Physionet 2012 dataset files from the Physionet website
    and extracts them into the local folder 'physionet2012'
@@ -359,3 +360,39 @@ def _get_physionet_y_dataframe(file_path: str) -> pd.DataFrame:
y.index.name = 'recordid'
y.reset_index(inplace=True)
return y


def get_covid_19() -> tuple:
"""
Loads Covid-19 dataset with additional graph information
The dataset is based on data from The New York Times, based on reports from state and local health agencies [1].
And was adapted to graph case in [2].
[1] The New York Times. (2021). Coronavirus (Covid-19) Data in the United States. Retrieved [Insert Date Here], from https://github.com/nytimes/covid-19-data.
[2] Alexander V. Nikitin, St John, Arno Solin, Samuel Kaski Proceedings of The 25th International Conference on Artificial Intelligence and Statistics, PMLR 151:10640-10660, 2022.
Returns:
-------
tuple
First element is time series data (n_nodes x n_timestamps x n_features). Each timestamp consists of
the number of deaths, cases, deaths normalized by the population, and cases normalized by the population.
The second element is the graph tuple (nodes, edges).
The third element is the order of states.
"""
base_url = "https://raw.githubusercontent.com/nytimes/covid-19-data/master/us-states.csv"
destination_folder = "covid19"
file_utils.download(base_url, destination_folder)
result, graph = covid19_data_utils.covid_dataset(
os.path.join(destination_folder, "us-states.csv")
)

processed_dataset = []
for timestamp in result.keys():
processed_dataset.append([])
for state in covid19_data_utils.LIST_OF_STATES:
cur_data = result[timestamp][state]
processed_dataset[-1].append(
[cur_data["deaths"], cur_data["cases"],
cur_data["deaths_normalized"], cur_data["cases_normalized"]]
)
return np.transpose(np.array(processed_dataset), (1, 0, 2)), graph, covid19_data_utils.LIST_OF_STATES
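
A short usage sketch mirroring test_get_covid_19 above (the exact number of weekly timestamps depends on when the NYT data snapshot is downloaded):

import tsgm

# Downloads us-states.csv into ./covid19 and assembles the graph time series dataset.
X, (nodes, edges), states = tsgm.utils.get_covid_19()

print(X.shape)      # (51, n_weeks, 4): deaths, cases, deaths_normalized, cases_normalized
print(len(edges))   # 220 directed adjacency edges
print(states[:3])   # alphabetical node order matching the first axis of X
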
