Skip to content

Commit

Permalink
feat: return ICUType in the physionet_2012 dataset
Browse files: browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Oct 9, 2023
1 parent 2dd0857 commit 60b42c9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
14 changes: 14 additions & 0 deletions pypots/data/generating.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,8 @@ def gene_physionet2012(artificially_missing_rate: float = 0.1):
dataset = load_specific_dataset("physionet_2012")
X = dataset["X"]
y = dataset["y"]
ICUType = dataset["ICUType"]

all_recordID = X["RecordID"].unique()
train_set_ids, test_set_ids = train_test_split(all_recordID, test_size=0.2)
train_set_ids, val_set_ids = train_test_split(train_set_ids, test_size=0.2)
Expand Down Expand Up @@ -385,16 +387,28 @@ def gene_physionet2012(artificially_missing_rate: float = 0.1):
test_y = y[y.index.isin(test_set_ids)].sort_index()
train_y, val_y, test_y = train_y.to_numpy(), val_y.to_numpy(), test_y.to_numpy()

train_ICUType = ICUType[ICUType.index.isin(train_set_ids)].sort_index()
val_ICUType = ICUType[ICUType.index.isin(val_set_ids)].sort_index()
test_ICUType = ICUType[ICUType.index.isin(test_set_ids)].sort_index()
train_ICUType, val_ICUType, test_ICUType = (
train_ICUType.to_numpy(),
val_ICUType.to_numpy(),
test_ICUType.to_numpy(),
)

data = {
"n_classes": 2,
"n_steps": 48,
"n_features": train_X.shape[-1],
"train_X": train_X,
"train_y": train_y.flatten(),
"train_ICUType": train_ICUType.flatten(),
"val_X": val_X,
"val_y": val_y.flatten(),
"val_ICUType": val_ICUType.flatten(),
"test_X": test_X,
"test_y": test_y.flatten(),
"test_ICUType": test_ICUType.flatten(),
"scaler": scaler,
}

Expand Down
7 changes: 5 additions & 2 deletions pypots/data/load_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def preprocess_physionet2012(data: dict) -> dict:
y : pandas.Series
The 11988 classification labels of all patients, indicating whether they were deceased.
"""
# remove the static features, e.g. age, gender
data["static_features"].remove("ICUType") # keep ICUType for now
# remove the other static features, e.g. age, gender
X = data["X"].drop(data["static_features"], axis=1)

def apply_func(df_temp): # pad and truncate to set the max length of samples as 48
Expand All @@ -41,11 +42,13 @@ def apply_func(df_temp): # pad and truncate to set the max length of samples as
X = X.groupby("RecordID").apply(apply_func)
X = X.drop("RecordID", axis=1)
X = X.reset_index()
X = X.drop(["level_1"], axis=1)
ICUType = X[["RecordID", "ICUType"]].set_index("RecordID").dropna()
X = X.drop(["level_1", "ICUType"], axis=1)

dataset = {
"X": X,
"y": data["y"],
"ICUType": ICUType,
}

return dataset

0 comments on commit 60b42c9

Please sign in to comment.