Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add csai test cases #535

Merged
2 commits merged on Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions tests/classification/csai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""
Test cases for CSAI classification model.
"""

# Created by Linglong Qian <[email protected]>
# License: BSD-3-Clause

import os
import unittest

import pytest

from pypots.classification import CSAI
from pypots.optim import Adam
from pypots.utils.logging import logger
from pypots.utils.metrics import calc_binary_classification_metrics
from tests.global_test_config import (
DATA,
EPOCHS,
DEVICE,
TRAIN_SET,
VAL_SET,
TEST_SET,
GENERAL_H5_TRAIN_SET_PATH,
GENERAL_H5_VAL_SET_PATH,
GENERAL_H5_TEST_SET_PATH,
RESULT_SAVING_DIR_FOR_CLASSIFICATION,
check_tb_and_model_checkpoints_existence,
)


class TestCSAI(unittest.TestCase):
    """End-to-end test cases for the CSAI classification model.

    The optimizer and the model are created once, in the class body, and are
    shared by every test method as class attributes. The test methods are
    numbered (``test_0_...`` to ``test_4_...``) because the later ones use the
    model instance that ``test_0_fit`` trains.
    """

    logger.info("Running tests for a classification model CSAI...")

    # Set the log and model saving path
    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "CSAI")
    model_save_name = "saved_CSAI_model.pypots"

    # Initialize an Adam optimizer
    optimizer = Adam(lr=0.001, weight_decay=1e-5)

    # Initialize the CSAI model for classification
    csai = CSAI(
        n_steps=DATA["n_steps"],
        n_features=DATA["n_features"],
        n_classes=DATA["n_classes"],
        rnn_hidden_size=32,
        imputation_weight=0.7,
        consistency_weight=0.3,
        classification_weight=1.0,
        removal_percent=10,
        increase_factor=0.1,
        compute_intervals=True,
        step_channels=16,
        batch_size=64,
        epochs=EPOCHS,
        dropout=0.5,
        optimizer=optimizer,
        num_workers=4,
        device=DEVICE,
        saving_path=saving_path,
        model_saving_strategy="better",
        verbose=True,
    )

    @pytest.mark.xdist_group(name="classification-csai")
    def test_0_fit(self):
        # Fit the CSAI model on the training and validation datasets
        self.csai.fit(TRAIN_SET, VAL_SET)

    @pytest.mark.xdist_group(name="classification-csai")
    def test_1_classify(self):
        # Classify test set using the trained CSAI model
        results = self.csai.classify(TEST_SET)

        # Calculate binary classification metrics
        metrics = calc_binary_classification_metrics(results, DATA["test_y"])

        logger.info(
            f'CSAI ROC_AUC: {metrics["roc_auc"]}, '
            f'PR_AUC: {metrics["pr_auc"]}, '
            f'F1: {metrics["f1"]}, '
            f'Precision: {metrics["precision"]}, '
            f'Recall: {metrics["recall"]}'
        )

        # The trained model must not do worse than the 0.5 ROC-AUC threshold
        assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5"

    @pytest.mark.xdist_group(name="classification-csai")
    def test_2_parameters(self):
        # Ensure that CSAI model parameters are properly initialized and trained
        assert hasattr(self.csai, "model") and self.csai.model is not None

        assert hasattr(self.csai, "optimizer") and self.csai.optimizer is not None

        # best_loss must have moved away from infinity, i.e. it was updated
        assert hasattr(self.csai, "best_loss")
        self.assertNotEqual(self.csai.best_loss, float("inf"))

        assert (
            hasattr(self.csai, "best_model_dict")
            and self.csai.best_model_dict is not None
        )

    @pytest.mark.xdist_group(name="classification-csai")
    def test_3_saving_path(self):
        # Ensure the root saving directory exists
        assert os.path.exists(
            self.saving_path
        ), f"file {self.saving_path} does not exist"

        # Check if the tensorboard file and model checkpoints exist
        check_tb_and_model_checkpoints_existence(self.csai)

        # Save the trained model to file, and verify the file existence.
        # model_save_name already ends with ".pypots", so the saved file is
        # expected at exactly this path.
        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
        self.csai.save(saved_model_path)
        assert os.path.exists(
            saved_model_path
        ), f"file {saved_model_path} does not exist"

        # Test loading the saved model
        self.csai.load(saved_model_path)

    @pytest.mark.xdist_group(name="classification-csai")
    def test_4_lazy_loading(self):
        # Fit the CSAI model using lazy-loading datasets from H5 files
        self.csai.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)

        # Perform classification using lazy-loaded data
        results = self.csai.classify(GENERAL_H5_TEST_SET_PATH)

        # Calculate binary classification metrics
        metrics = calc_binary_classification_metrics(results, DATA["test_y"])

        logger.info(
            f'Lazy-loading CSAI ROC_AUC: {metrics["roc_auc"]}, '
            f'PR_AUC: {metrics["pr_auc"]}, '
            f'F1: {metrics["f1"]}, '
            f'Precision: {metrics["precision"]}, '
            f'Recall: {metrics["recall"]}'
        )

        # Same 0.5 ROC-AUC threshold as the in-memory classification test
        assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5"


# Run the test cases with unittest's runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
137 changes: 137 additions & 0 deletions tests/imputation/csai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
Test cases for CSAI imputation model.
"""

# Created by Linglong Qian <[email protected]>
# License: BSD-3-Clause


import os.path
import unittest

import numpy as np
import pytest

from pypots.imputation import CSAI
from pypots.optim import Adam
from pypots.utils.logging import logger
from pypots.utils.metrics import calc_mse
from tests.global_test_config import (
DATA,
EPOCHS,
DEVICE,
TRAIN_SET,
VAL_SET,
TEST_SET,
GENERAL_H5_TRAIN_SET_PATH,
GENERAL_H5_VAL_SET_PATH,
GENERAL_H5_TEST_SET_PATH,
RESULT_SAVING_DIR_FOR_IMPUTATION,
check_tb_and_model_checkpoints_existence,
)


class TestCSAI(unittest.TestCase):
    """End-to-end test cases for the CSAI imputation model.

    The optimizer and the model are created once, in the class body, and are
    shared by every test method as class attributes. The test methods are
    numbered (``test_0_...`` to ``test_4_...``) because the later ones use the
    model instance that ``test_0_fit`` trains.
    """

    logger.info("Running tests for the CSAI imputation model...")

    # Set the log and model saving path
    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "CSAI")
    model_save_name = "saved_CSAI_model.pypots"

    # Initialize an Adam optimizer
    optimizer = Adam(lr=0.001, weight_decay=1e-5)

    # Initialize the CSAI model
    csai = CSAI(
        n_steps=DATA["n_steps"],
        n_features=DATA["n_features"],
        rnn_hidden_size=32,
        imputation_weight=0.7,
        consistency_weight=0.3,
        removal_percent=10,  # Assume we are removing 10% of the data
        increase_factor=0.1,
        compute_intervals=True,
        step_channels=16,
        batch_size=64,
        epochs=EPOCHS,
        optimizer=optimizer,
        num_workers=0,
        device=DEVICE,
        saving_path=saving_path,
        model_saving_strategy="best",
        verbose=True,
    )

    @pytest.mark.xdist_group(name="imputation-csai")
    def test_0_fit(self):
        # Fit the CSAI model on the training and validation datasets
        self.csai.fit(TRAIN_SET, VAL_SET)

    @pytest.mark.xdist_group(name="imputation-csai")
    def test_1_impute(self):
        # Impute missing values using the trained CSAI model
        imputed_X = self.csai.impute(TEST_SET)
        assert not np.isnan(
            imputed_X
        ).any(), "Output still has missing values after running impute()."

        # Calculate mean squared error (MSE) for the test set.
        # The value is only logged; no threshold is asserted on it.
        test_MSE = calc_mse(
            imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
        )
        logger.info(f"CSAI test_MSE: {test_MSE}")

    @pytest.mark.xdist_group(name="imputation-csai")
    def test_2_parameters(self):
        # Ensure that CSAI model parameters are properly initialized and trained
        assert hasattr(self.csai, "model") and self.csai.model is not None

        assert hasattr(self.csai, "optimizer") and self.csai.optimizer is not None

        # best_loss must have moved away from infinity, i.e. it was updated
        assert hasattr(self.csai, "best_loss")
        self.assertNotEqual(self.csai.best_loss, float("inf"))

        assert (
            hasattr(self.csai, "best_model_dict")
            and self.csai.best_model_dict is not None
        )

    @pytest.mark.xdist_group(name="imputation-csai")
    def test_3_saving_path(self):
        # Ensure the root saving directory exists
        assert os.path.exists(
            self.saving_path
        ), f"file {self.saving_path} does not exist"

        # Check if the tensorboard file and model checkpoints exist
        check_tb_and_model_checkpoints_existence(self.csai)

        # Save the trained model to file, and verify the file existence.
        # model_save_name already ends with ".pypots", so the saved file is
        # expected at exactly this path.
        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
        self.csai.save(saved_model_path)
        assert os.path.exists(
            saved_model_path
        ), f"file {saved_model_path} does not exist"

        # Test loading the saved model
        self.csai.load(saved_model_path)

    @pytest.mark.xdist_group(name="imputation-csai")
    def test_4_lazy_loading(self):
        # Fit the CSAI model using lazy-loading datasets from H5 files
        self.csai.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)

        # Perform imputation using lazy-loaded data
        imputation_results = self.csai.predict(GENERAL_H5_TEST_SET_PATH)
        assert not np.isnan(
            imputation_results["imputation"]
        ).any(), "Output still has missing values after running impute()."

        # Calculate the MSE on the test set.
        # The value is only logged; no threshold is asserted on it.
        test_MSE = calc_mse(
            imputation_results["imputation"],
            DATA["test_X_ori"],
            DATA["test_X_indicating_mask"],
        )
        logger.info(f"Lazy-loading CSAI test_MSE: {test_MSE}")


# Run the test cases with unittest's runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()