Skip to content

Commit

Permalink
add csai test cases (#535)
Browse files Browse the repository at this point in the history
  • Loading branch information
LinglongQian authored Oct 8, 2024
1 parent 6c5777e commit a4f1a72
Show file tree
Hide file tree
Showing 2 changed files with 285 additions and 0 deletions.
148 changes: 148 additions & 0 deletions tests/classification/csai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""
Test cases for CSAI classification model.
"""

# Created by Linglong Qian <[email protected]>
# License: BSD-3-Clause

import os
import unittest

import pytest

from pypots.classification import CSAI
from pypots.optim import Adam
from pypots.utils.logging import logger
from pypots.utils.metrics import calc_binary_classification_metrics
from tests.global_test_config import (
DATA,
EPOCHS,
DEVICE,
TRAIN_SET,
VAL_SET,
TEST_SET,
GENERAL_H5_TRAIN_SET_PATH,
GENERAL_H5_VAL_SET_PATH,
GENERAL_H5_TEST_SET_PATH,
RESULT_SAVING_DIR_FOR_CLASSIFICATION,
check_tb_and_model_checkpoints_existence,
)


class TestCSAI(unittest.TestCase):
logger.info("Running tests for a classification model CSAI...")

# Set the log and model saving path
saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "CSAI")
model_save_name = "saved_CSAI_model.pypots"

# Initialize an Adam optimizer
optimizer = Adam(lr=0.001, weight_decay=1e-5)

# Initialize the CSAI model for classification
csai = CSAI(
n_steps=DATA["n_steps"],
n_features=DATA["n_features"],
n_classes=DATA["n_classes"],
rnn_hidden_size=32,
imputation_weight=0.7,
consistency_weight=0.3,
classification_weight=1.0,
removal_percent=10,
increase_factor=0.1,
compute_intervals=True,
step_channels=16,
batch_size=64,
epochs=EPOCHS,
dropout=0.5,
optimizer=optimizer,
num_workers=4,
device=DEVICE,
saving_path=saving_path,
model_saving_strategy="better",
verbose=True,
)

@pytest.mark.xdist_group(name="classification-csai")
def test_0_fit(self):
# Fit the CSAI model on the training and validation datasets
self.csai.fit(TRAIN_SET, VAL_SET)

@pytest.mark.xdist_group(name="classification-csai")
def test_1_classify(self):
# Classify test set using the trained CSAI model
results = self.csai.classify(TEST_SET)

# Calculate binary classification metrics
metrics = calc_binary_classification_metrics(
results, DATA["test_y"]
)

logger.info(
f'CSAI ROC_AUC: {metrics["roc_auc"]}, '
f'PR_AUC: {metrics["pr_auc"]}, '
f'F1: {metrics["f1"]}, '
f'Precision: {metrics["precision"]}, '
f'Recall: {metrics["recall"]}'
)

assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5"

@pytest.mark.xdist_group(name="classification-csai")
def test_2_parameters(self):
# Ensure that CSAI model parameters are properly initialized and trained
assert hasattr(self.csai, "model") and self.csai.model is not None

assert hasattr(self.csai, "optimizer") and self.csai.optimizer is not None

assert hasattr(self.csai, "best_loss")
self.assertNotEqual(self.csai.best_loss, float("inf"))

assert (
hasattr(self.csai, "best_model_dict")
and self.csai.best_model_dict is not None
)

@pytest.mark.xdist_group(name="classification-csai")
def test_3_saving_path(self):
# Ensure the root saving directory exists
assert os.path.exists(
self.saving_path
), f"file {self.saving_path} does not exist"

# Check if the tensorboard file and model checkpoints exist
check_tb_and_model_checkpoints_existence(self.csai)

# Save the trained model to file, and verify the file existence
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.csai.save(saved_model_path)

# Test loading the saved model
self.csai.load(saved_model_path)

@pytest.mark.xdist_group(name="classification-csai")
def test_4_lazy_loading(self):
# Fit the CSAI model using lazy-loading datasets from H5 files
self.csai.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)

# Perform classification using lazy-loaded data
results = self.csai.classify(GENERAL_H5_TEST_SET_PATH)

# Calculate binary classification metrics
metrics = calc_binary_classification_metrics(
results, DATA["test_y"]
)

logger.info(
f'Lazy-loading CSAI ROC_AUC: {metrics["roc_auc"]}, '
f'PR_AUC: {metrics["pr_auc"]}, '
f'F1: {metrics["f1"]}, '
f'Precision: {metrics["precision"]}, '
f'Recall: {metrics["recall"]}'
)

assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5"


if __name__ == "__main__":
unittest.main()
137 changes: 137 additions & 0 deletions tests/imputation/csai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
Test cases for CSAI imputation model.
"""

# Created by Linglong Qian <[email protected]>
# License: BSD-3-Clause


import os.path
import unittest

import numpy as np
import pytest

from pypots.imputation import CSAI
from pypots.optim import Adam
from pypots.utils.logging import logger
from pypots.utils.metrics import calc_mse
from tests.global_test_config import (
DATA,
EPOCHS,
DEVICE,
TRAIN_SET,
VAL_SET,
TEST_SET,
GENERAL_H5_TRAIN_SET_PATH,
GENERAL_H5_VAL_SET_PATH,
GENERAL_H5_TEST_SET_PATH,
RESULT_SAVING_DIR_FOR_IMPUTATION,
check_tb_and_model_checkpoints_existence,
)


class TestCSAI(unittest.TestCase):
logger.info("Running tests for the CSAI imputation model...")

# Set the log and model saving path
saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "CSAI")
model_save_name = "saved_CSAI_model.pypots"

# Initialize an Adam optimizer
optimizer = Adam(lr=0.001, weight_decay=1e-5)

# Initialize the CSAI model
csai = CSAI(
n_steps=DATA["n_steps"],
n_features=DATA["n_features"],
rnn_hidden_size=32,
imputation_weight=0.7,
consistency_weight=0.3,
removal_percent=10, # Assume we are removing 10% of the data
increase_factor=0.1,
compute_intervals=True,
step_channels=16,
batch_size=64,
epochs=EPOCHS,
optimizer=optimizer,
num_workers=0,
device=DEVICE,
saving_path=saving_path,
model_saving_strategy="best",
verbose=True,
)

@pytest.mark.xdist_group(name="imputation-csai")
def test_0_fit(self):
# Fit the CSAI model on the training and validation datasets
self.csai.fit(TRAIN_SET, VAL_SET)

@pytest.mark.xdist_group(name="imputation-csai")
def test_1_impute(self):
# Impute missing values using the trained CSAI model
imputed_X = self.csai.impute(TEST_SET)
assert not np.isnan(
imputed_X
).any(), "Output still has missing values after running impute()."

# Calculate mean squared error (MSE) for the test set
test_MSE = calc_mse(
imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
)
logger.info(f"CSAI test_MSE: {test_MSE}")

@pytest.mark.xdist_group(name="imputation-csai")
def test_2_parameters(self):
# Ensure that CSAI model parameters are properly initialized and trained
assert hasattr(self.csai, "model") and self.csai.model is not None

assert hasattr(self.csai, "optimizer") and self.csai.optimizer is not None

assert hasattr(self.csai, "best_loss")
self.assertNotEqual(self.csai.best_loss, float("inf"))

assert (
hasattr(self.csai, "best_model_dict")
and self.csai.best_model_dict is not None
)

@pytest.mark.xdist_group(name="imputation-csai")
def test_3_saving_path(self):
# Ensure the root saving directory exists
assert os.path.exists(
self.saving_path
), f"file {self.saving_path} does not exist"

# Check if the tensorboard file and model checkpoints exist
check_tb_and_model_checkpoints_existence(self.csai)

# Save the trained model to file, and verify the file existence
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.csai.save(saved_model_path)

# Test loading the saved model
self.csai.load(saved_model_path)

@pytest.mark.xdist_group(name="imputation-csai")
def test_4_lazy_loading(self):
# Fit the CSAI model using lazy-loading datasets from H5 files
self.csai.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)

# Perform imputation using lazy-loaded data
imputation_results = self.csai.predict(GENERAL_H5_TEST_SET_PATH)
assert not np.isnan(
imputation_results["imputation"]
).any(), "Output still has missing values after running impute()."

# Calculate the MSE on the test set
test_MSE = calc_mse(
imputation_results["imputation"],
DATA["test_X_ori"],
DATA["test_X_indicating_mask"],
)
logger.info(f"Lazy-loading CSAI test_MSE: {test_MSE}")


if __name__ == "__main__":
unittest.main()

0 comments on commit a4f1a72

Please sign in to comment.