diff --git a/pypots/classification/csai/data.py b/pypots/classification/csai/data.py
index cd829882..3b93765c 100644
--- a/pypots/classification/csai/data.py
+++ b/pypots/classification/csai/data.py
@@ -6,6 +6,7 @@
 # License: BSD-3-Clause
 
 from typing import Union
+
 from ...imputation.csai.data import DatasetForCSAI as DatasetForCSAI_Imputation
 
diff --git a/pypots/classification/csai/model.py b/pypots/classification/csai/model.py
index a504cc1d..c65a3724 100644
--- a/pypots/classification/csai/model.py
+++ b/pypots/classification/csai/model.py
@@ -13,8 +13,11 @@
 from .core import _BCSAI
 from .data import DatasetForCSAI
 from ..base import BaseNNClassifier
+from ...data.checking import key_in_data_set
+from ...data.saving.h5 import load_dict_from_h5
 from ...optim.adam import Adam
 from ...optim.base import Optimizer
+from ...utils.logging import logger
 
 
 class CSAI(BaseNNClassifier):
@@ -171,6 +174,7 @@ def __init__(
 
         # set up the optimizer
         self.optimizer = optimizer
+        self.optimizer.init_optimizer(self.model.parameters())
 
     def _assemble_input_for_training(self, data: list, training=True) -> dict:
         # extract data
@@ -245,6 +249,12 @@
         file_type: str = "hdf5",
     ) -> None:
         # Create dataset
+        if isinstance(train_set, str):
+            logger.warning(
+                "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
+                "Hence the whole train set will be loaded into memory."
+            )
+            train_set = load_dict_from_h5(train_set)
         training_set = DatasetForCSAI(
             data=train_set,
             file_type=file_type,
@@ -267,6 +277,15 @@
         )
         val_loader = None
         if val_set is not None:
+            if isinstance(val_set, str):
+                logger.warning(
+                    "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
+                    "Hence the whole val set will be loaded into memory."
+                )
+                val_set = load_dict_from_h5(val_set)
+
+            if not key_in_data_set("X_ori", val_set):
+                raise ValueError("val_set must contain 'X_ori' for model validation.")
             val_set = DatasetForCSAI(
                 data=val_set,
                 file_type=file_type,
@@ -284,24 +303,6 @@
                 shuffle=False,
                 num_workers=self.num_workers,
             )
-        # Create model
-        self.model = _BCSAI(
-            n_steps=self.n_steps,
-            n_features=self.n_features,
-            rnn_hidden_size=self.rnn_hidden_size,
-            imputation_weight=self.imputation_weight,
-            consistency_weight=self.consistency_weight,
-            classification_weight=self.classification_weight,
-            n_classes=self.n_classes,
-            step_channels=self.step_channels,
-            dropout=self.dropout,
-            intervals=self.intervals,
-        )
-        self._send_model_to_given_device()
-        self._print_model_size()
-
-        # set up the optimizer
-        self.optimizer.init_optimizer(self.model.parameters())
 
         # train the model
         self._train_model(train_loader, val_loader)
@@ -317,6 +318,13 @@
     ) -> dict:
 
         self.model.eval()
+
+        if isinstance(test_set, str):
+            logger.warning(
+                "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
+                "Hence the whole test set will be loaded into memory."
+            )
+            test_set = load_dict_from_h5(test_set)
         test_set = DatasetForCSAI(
             data=test_set,
             file_type=file_type,
diff --git a/pypots/imputation/csai/model.py b/pypots/imputation/csai/model.py
index 4eaab839..fe655ea2 100644
--- a/pypots/imputation/csai/model.py
+++ b/pypots/imputation/csai/model.py
@@ -6,7 +6,6 @@
 # License: BSD-3-Clause
 
 from typing import Union, Optional
-from venv import logger
 
 import numpy as np
 import torch
@@ -19,6 +18,7 @@
 from ...data.saving.h5 import load_dict_from_h5
 from ...optim.adam import Adam
 from ...optim.base import Optimizer
+from ...utils.logging import logger
 
 
 class CSAI(BaseNNImputer):
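
Note (illustrative, not part of the patch): the sketch below shows how the patched classification `CSAI` behaves from the caller's side. The constructor keyword names are taken from the `_BCSAI` call this diff removes from `fit()`; the numeric values, file names, and dataset keys are placeholder assumptions.

```python
# Minimal usage sketch under the assumptions stated above.
from pypots.classification import CSAI

model = CSAI(
    n_steps=48,
    n_features=35,
    n_classes=2,
    rnn_hidden_size=108,
    imputation_weight=0.3,
    consistency_weight=0.1,
    classification_weight=1.0,
    step_channels=512,
    dropout=0.5,
)

# With this patch, passing an HDF5 path triggers a warning and eager loading via
# load_dict_from_h5, since the normalisation mean/std must be computed over the
# full training set. The validation set must additionally contain "X_ori".
model.fit(train_set="train.h5", val_set="val.h5")

# predict() applies the same eager-loading rule to the test set.
results = model.predict(test_set="test.h5")
```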