Skip to content

Commit

Permalink
fix: lazy loading error for classification CSAI
Browse files: browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Oct 27, 2024
1 parent b6a3280 commit edd144d
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 19 deletions.
1 change: 1 addition & 0 deletions pypots/classification/csai/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# License: BSD-3-Clause

from typing import Union

from ...imputation.csai.data import DatasetForCSAI as DatasetForCSAI_Imputation


Expand Down
44 changes: 26 additions & 18 deletions pypots/classification/csai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
from .core import _BCSAI
from .data import DatasetForCSAI
from ..base import BaseNNClassifier
from ...data.checking import key_in_data_set
from ...data.saving.h5 import load_dict_from_h5
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger


class CSAI(BaseNNClassifier):
Expand Down Expand Up @@ -171,6 +174,7 @@ def __init__(

# set up the optimizer
self.optimizer = optimizer
self.optimizer.init_optimizer(self.model.parameters())

def _assemble_input_for_training(self, data: list, training=True) -> dict:
# extract data
Expand Down Expand Up @@ -245,6 +249,12 @@ def fit(
file_type: str = "hdf5",
) -> None:
# Create dataset
if isinstance(train_set, str):
logger.warning(
"CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
"Hence the whole train set will be loaded into memory."
)
train_set = load_dict_from_h5(train_set)
training_set = DatasetForCSAI(
data=train_set,
file_type=file_type,
Expand All @@ -267,6 +277,15 @@ def fit(
)
val_loader = None
if val_set is not None:
if isinstance(val_set, str):
logger.warning(
"CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
"Hence the whole val set will be loaded into memory."
)
val_set = load_dict_from_h5(val_set)

if not key_in_data_set("X_ori", val_set):
raise ValueError("val_set must contain 'X_ori' for model validation.")
val_set = DatasetForCSAI(
data=val_set,
file_type=file_type,
Expand All @@ -284,24 +303,6 @@ def fit(
shuffle=False,
num_workers=self.num_workers,
)
# Create model
self.model = _BCSAI(
n_steps=self.n_steps,
n_features=self.n_features,
rnn_hidden_size=self.rnn_hidden_size,
imputation_weight=self.imputation_weight,
consistency_weight=self.consistency_weight,
classification_weight=self.classification_weight,
n_classes=self.n_classes,
step_channels=self.step_channels,
dropout=self.dropout,
intervals=self.intervals,
)
self._send_model_to_given_device()
self._print_model_size()

# set up the optimizer
self.optimizer.init_optimizer(self.model.parameters())

# train the model
self._train_model(train_loader, val_loader)
Expand All @@ -317,6 +318,13 @@ def predict(
) -> dict:

self.model.eval()

if isinstance(test_set, str):
logger.warning(
"CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
"Hence the whole test set will be loaded into memory."
)
test_set = load_dict_from_h5(test_set)
test_set = DatasetForCSAI(
data=test_set,
file_type=file_type,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/csai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# License: BSD-3-Clause

from typing import Union, Optional
from venv import logger

import numpy as np
import torch
Expand All @@ -19,6 +18,7 @@
from ...data.saving.h5 import load_dict_from_h5
from ...optim.adam import Adam
from ...optim.base import Optimizer
from ...utils.logging import logger


class CSAI(BaseNNImputer):
Expand Down

0 comments on commit edd144d

Please sign in to comment.