From a4f1a72d2f8ff8f927296b273f5986beeb164f0f Mon Sep 17 00:00:00 2001
From: LinglongQian <38267728+LinglongQian@users.noreply.github.com>
Date: Tue, 8 Oct 2024 17:33:31 +0100
Subject: [PATCH] add csai test cases (#535)

---
 tests/classification/csai.py | 148 +++++++++++++++++++++++++++++++++++
 tests/imputation/csai.py     | 137 ++++++++++++++++++++++++++++++++
 2 files changed, 285 insertions(+)
 create mode 100644 tests/classification/csai.py
 create mode 100644 tests/imputation/csai.py

diff --git a/tests/classification/csai.py b/tests/classification/csai.py
new file mode 100644
index 00000000..916c6cf8
--- /dev/null
+++ b/tests/classification/csai.py
@@ -0,0 +1,148 @@
+"""
+Test cases for CSAI classification model.
+"""
+
+# Created by Linglong Qian <linglong.qian@kcl.ac.uk>
+# License: BSD-3-Clause
+
+import os
+import unittest
+
+import pytest
+
+from pypots.classification import CSAI
+from pypots.optim import Adam
+from pypots.utils.logging import logger
+from pypots.utils.metrics import calc_binary_classification_metrics
+from tests.global_test_config import (
+    DATA,
+    EPOCHS,
+    DEVICE,
+    TRAIN_SET,
+    VAL_SET,
+    TEST_SET,
+    GENERAL_H5_TRAIN_SET_PATH,
+    GENERAL_H5_VAL_SET_PATH,
+    GENERAL_H5_TEST_SET_PATH,
+    RESULT_SAVING_DIR_FOR_CLASSIFICATION,
+    check_tb_and_model_checkpoints_existence,
+)
+
+
+class TestCSAI(unittest.TestCase):
+    logger.info("Running tests for a classification model CSAI...")
+
+    # Set the log and model saving path
+    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "CSAI")
+    model_save_name = "saved_CSAI_model.pypots"
+
+    # Initialize an Adam optimizer
+    optimizer = Adam(lr=0.001, weight_decay=1e-5)
+
+    # Initialize the CSAI model for classification
+    csai = CSAI(
+        n_steps=DATA["n_steps"],
+        n_features=DATA["n_features"],
+        n_classes=DATA["n_classes"],
+        rnn_hidden_size=32,
+        imputation_weight=0.7,
+        consistency_weight=0.3,
+        classification_weight=1.0,
+        removal_percent=10,
+        increase_factor=0.1,
+        compute_intervals=True,
+        step_channels=16,
+        batch_size=64,
+        epochs=EPOCHS,
+        dropout=0.5,
+        optimizer=optimizer,
+        num_workers=4,
+        device=DEVICE,
+        saving_path=saving_path,
+        model_saving_strategy="better",
+        verbose=True,
+    )
+
+    @pytest.mark.xdist_group(name="classification-csai")
+    def test_0_fit(self):
+        # Fit the CSAI model on the training and validation datasets
+        self.csai.fit(TRAIN_SET, VAL_SET)
+
+    @pytest.mark.xdist_group(name="classification-csai")
+    def test_1_classify(self):
+        # Classify test set using the trained CSAI model
+        results = self.csai.classify(TEST_SET)
+        
+        # Calculate binary classification metrics
+        metrics = calc_binary_classification_metrics(
+            results, DATA["test_y"]
+        )
+        
+        logger.info(
+            f'CSAI ROC_AUC: {metrics["roc_auc"]}, '
+            f'PR_AUC: {metrics["pr_auc"]}, '
+            f'F1: {metrics["f1"]}, '
+            f'Precision: {metrics["precision"]}, '
+            f'Recall: {metrics["recall"]}'
+        )
+        
+        assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5"
+
+    @pytest.mark.xdist_group(name="classification-csai")
+    def test_2_parameters(self):
+        # Ensure that CSAI model parameters are properly initialized and trained
+        assert hasattr(self.csai, "model") and self.csai.model is not None
+
+        assert hasattr(self.csai, "optimizer") and self.csai.optimizer is not None
+
+        assert hasattr(self.csai, "best_loss")
+        self.assertNotEqual(self.csai.best_loss, float("inf"))
+
+        assert (
+            hasattr(self.csai, "best_model_dict")
+            and self.csai.best_model_dict is not None
+        )
+
+    @pytest.mark.xdist_group(name="classification-csai")
+    def test_3_saving_path(self):
+        # Ensure the root saving directory exists
+        assert os.path.exists(
+            self.saving_path
+        ), f"file {self.saving_path} does not exist"
+
+        # Check if the tensorboard file and model checkpoints exist
+        check_tb_and_model_checkpoints_existence(self.csai)
+
+        # Save the trained model to file, and verify the file existence
+        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+        self.csai.save(saved_model_path)
+
+        # Test loading the saved model
+        self.csai.load(saved_model_path)
+
+    @pytest.mark.xdist_group(name="classification-csai")
+    def test_4_lazy_loading(self):
+        # Fit the CSAI model using lazy-loading datasets from H5 files
+        self.csai.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)
+        
+        # Perform classification using lazy-loaded data
+        results = self.csai.classify(GENERAL_H5_TEST_SET_PATH)
+        
+        # Calculate binary classification metrics
+        metrics = calc_binary_classification_metrics(
+            results, DATA["test_y"]
+        )
+        
+        logger.info(
+            f'Lazy-loading CSAI ROC_AUC: {metrics["roc_auc"]}, '
+            f'PR_AUC: {metrics["pr_auc"]}, '
+            f'F1: {metrics["f1"]}, '
+            f'Precision: {metrics["precision"]}, '
+            f'Recall: {metrics["recall"]}'
+        )
+        
+        assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5"
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/imputation/csai.py b/tests/imputation/csai.py
new file mode 100644
index 00000000..492a0a52
--- /dev/null
+++ b/tests/imputation/csai.py
@@ -0,0 +1,137 @@
+"""
+Test cases for CSAI imputation model.
+"""
+
+# Created by Linglong Qian <linglong.qian@kcl.ac.uk>
+# License: BSD-3-Clause
+
+
+import os.path
+import unittest
+
+import numpy as np
+import pytest
+
+from pypots.imputation import CSAI
+from pypots.optim import Adam
+from pypots.utils.logging import logger
+from pypots.utils.metrics import calc_mse
+from tests.global_test_config import (
+    DATA,
+    EPOCHS,
+    DEVICE,
+    TRAIN_SET,
+    VAL_SET,
+    TEST_SET,
+    GENERAL_H5_TRAIN_SET_PATH,
+    GENERAL_H5_VAL_SET_PATH,
+    GENERAL_H5_TEST_SET_PATH,
+    RESULT_SAVING_DIR_FOR_IMPUTATION,
+    check_tb_and_model_checkpoints_existence,
+)
+
+
+class TestCSAI(unittest.TestCase):
+    logger.info("Running tests for the CSAI imputation model...")
+
+    # Set the log and model saving path
+    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "CSAI")
+    model_save_name = "saved_CSAI_model.pypots"
+
+    # Initialize an Adam optimizer
+    optimizer = Adam(lr=0.001, weight_decay=1e-5)
+
+    # Initialize the CSAI model
+    csai = CSAI(
+        n_steps=DATA["n_steps"],
+        n_features=DATA["n_features"],
+        rnn_hidden_size=32,
+        imputation_weight=0.7,
+        consistency_weight=0.3,
+        removal_percent=10,  # Assume we are removing 10% of the data
+        increase_factor=0.1,
+        compute_intervals=True,
+        step_channels=16,
+        batch_size=64,
+        epochs=EPOCHS,
+        optimizer=optimizer,
+        num_workers=0,
+        device=DEVICE,
+        saving_path=saving_path,
+        model_saving_strategy="best",
+        verbose=True,
+    )
+
+    @pytest.mark.xdist_group(name="imputation-csai")
+    def test_0_fit(self):
+        # Fit the CSAI model on the training and validation datasets
+        self.csai.fit(TRAIN_SET, VAL_SET)
+
+    @pytest.mark.xdist_group(name="imputation-csai")
+    def test_1_impute(self):
+        # Impute missing values using the trained CSAI model
+        imputed_X = self.csai.impute(TEST_SET)
+        assert not np.isnan(
+            imputed_X
+        ).any(), "Output still has missing values after running impute()."
+        
+        # Calculate mean squared error (MSE) for the test set
+        test_MSE = calc_mse(
+            imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
+        )
+        logger.info(f"CSAI test_MSE: {test_MSE}")
+
+    @pytest.mark.xdist_group(name="imputation-csai")
+    def test_2_parameters(self):
+        # Ensure that CSAI model parameters are properly initialized and trained
+        assert hasattr(self.csai, "model") and self.csai.model is not None
+
+        assert hasattr(self.csai, "optimizer") and self.csai.optimizer is not None
+
+        assert hasattr(self.csai, "best_loss")
+        self.assertNotEqual(self.csai.best_loss, float("inf"))
+
+        assert (
+            hasattr(self.csai, "best_model_dict")
+            and self.csai.best_model_dict is not None
+        )
+
+    @pytest.mark.xdist_group(name="imputation-csai")
+    def test_3_saving_path(self):
+        # Ensure the root saving directory exists
+        assert os.path.exists(
+            self.saving_path
+        ), f"file {self.saving_path} does not exist"
+
+        # Check if the tensorboard file and model checkpoints exist
+        check_tb_and_model_checkpoints_existence(self.csai)
+
+        # Save the trained model to file, and verify the file existence
+        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+        self.csai.save(saved_model_path)
+
+        # Test loading the saved model
+        self.csai.load(saved_model_path)
+
+    @pytest.mark.xdist_group(name="imputation-csai")
+    def test_4_lazy_loading(self):
+        # Fit the CSAI model using lazy-loading datasets from H5 files
+        self.csai.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)
+        
+        # Perform imputation using lazy-loaded data
+        imputation_results = self.csai.predict(GENERAL_H5_TEST_SET_PATH)
+        assert not np.isnan(
+            imputation_results["imputation"]
+        ).any(), "Output still has missing values after running impute()."
+
+        # Calculate the MSE on the test set
+        test_MSE = calc_mse(
+            imputation_results["imputation"],
+            DATA["test_X_ori"],
+            DATA["test_X_indicating_mask"],
+        )
+        logger.info(f"Lazy-loading CSAI test_MSE: {test_MSE}")
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file