From 60b42c98d725faaa62a360651f10a181ad7085c0 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 9 Oct 2023 23:36:52 +0800 Subject: [PATCH] feat: return ICUType in the physionet_2012 dataset; --- pypots/data/generating.py | 14 ++++++++++++++ pypots/data/load_preprocessing.py | 7 +++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/pypots/data/generating.py b/pypots/data/generating.py index f0a20473..4b462c2c 100644 --- a/pypots/data/generating.py +++ b/pypots/data/generating.py @@ -350,6 +350,8 @@ def gene_physionet2012(artificially_missing_rate: float = 0.1): dataset = load_specific_dataset("physionet_2012") X = dataset["X"] y = dataset["y"] + ICUType = dataset["ICUType"] + all_recordID = X["RecordID"].unique() train_set_ids, test_set_ids = train_test_split(all_recordID, test_size=0.2) train_set_ids, val_set_ids = train_test_split(train_set_ids, test_size=0.2) @@ -385,16 +387,28 @@ def gene_physionet2012(artificially_missing_rate: float = 0.1): test_y = y[y.index.isin(test_set_ids)].sort_index() train_y, val_y, test_y = train_y.to_numpy(), val_y.to_numpy(), test_y.to_numpy() + train_ICUType = ICUType[ICUType.index.isin(train_set_ids)].sort_index() + val_ICUType = ICUType[ICUType.index.isin(val_set_ids)].sort_index() + test_ICUType = ICUType[ICUType.index.isin(test_set_ids)].sort_index() + train_ICUType, val_ICUType, test_ICUType = ( + train_ICUType.to_numpy(), + val_ICUType.to_numpy(), + test_ICUType.to_numpy(), + ) + data = { "n_classes": 2, "n_steps": 48, "n_features": train_X.shape[-1], "train_X": train_X, "train_y": train_y.flatten(), + "train_ICUType": train_ICUType.flatten(), "val_X": val_X, "val_y": val_y.flatten(), + "val_ICUType": val_ICUType.flatten(), "test_X": test_X, "test_y": test_y.flatten(), + "test_ICUType": test_ICUType.flatten(), "scaler": scaler, } diff --git a/pypots/data/load_preprocessing.py b/pypots/data/load_preprocessing.py index 0968ab5b..789cc0ce 100644 --- a/pypots/data/load_preprocessing.py +++ b/pypots/data/load_preprocessing.py @@ -25,7 +25,8 @@ def preprocess_physionet2012(data: dict) -> dict: y : pandas.Series The 11988 classification labels of all patients, indicating whether they were deceased. """ - # remove the static features, e.g. age, gender + data["static_features"].remove("ICUType") # keep ICUType for now + # remove the other static features, e.g. age, gender X = data["X"].drop(data["static_features"], axis=1) def apply_func(df_temp): # pad and truncate to set the max length of samples as 48 @@ -41,11 +42,13 @@ def apply_func(df_temp): # pad and truncate to set the max length of samples as X = X.groupby("RecordID").apply(apply_func) X = X.drop("RecordID", axis=1) X = X.reset_index() - X = X.drop(["level_1"], axis=1) + ICUType = X[["RecordID", "ICUType"]].set_index("RecordID").dropna() + X = X.drop(["level_1", "ICUType"], axis=1) dataset = { "X": X, "y": data["y"], + "ICUType": ICUType, } return dataset