diff --git a/pypots/classification/csai/model.py b/pypots/classification/csai/model.py index c65a3724..3419c5bb 100644 --- a/pypots/classification/csai/model.py +++ b/pypots/classification/csai/model.py @@ -116,10 +116,10 @@ def __init__( increase_factor: float, compute_intervals: bool, step_channels: int, - batch_size: int, - epochs: int, dropout: float = 0.5, - patience: Union[int, None] = None, + batch_size: int = 32, + epochs: int = 100, + patience: Optional[int] = None, optimizer: Optimizer = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, diff --git a/pypots/imputation/csai/model.py b/pypots/imputation/csai/model.py index fe655ea2..a579fd2c 100644 --- a/pypots/imputation/csai/model.py +++ b/pypots/imputation/csai/model.py @@ -116,9 +116,9 @@ def __init__( increase_factor: float, compute_intervals: bool, step_channels: int, - batch_size: int, - epochs: int, - patience: Union[int, None] = None, + batch_size: int = 32, + epochs: int = 100, + patience: Optional[int] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Union[str, torch.device, list, None] = None, diff --git a/tests/classification/csai.py b/tests/classification/csai.py index 17f1028f..4a2bbf5f 100644 --- a/tests/classification/csai.py +++ b/tests/classification/csai.py @@ -44,7 +44,7 @@ class TestCSAI(unittest.TestCase): n_steps=DATA["n_steps"], n_features=DATA["n_features"], n_classes=DATA["n_classes"], - rnn_hidden_size=32, + rnn_hidden_size=64, imputation_weight=0.7, consistency_weight=0.3, classification_weight=1.0, @@ -52,11 +52,9 @@ class TestCSAI(unittest.TestCase): increase_factor=0.1, compute_intervals=True, step_channels=16, - batch_size=64, epochs=EPOCHS, dropout=0.5, optimizer=optimizer, - num_workers=4, device=DEVICE, saving_path=saving_path, model_saving_strategy="better", diff --git a/tests/imputation/csai.py b/tests/imputation/csai.py index 986657e9..f5c4873b 100644 --- a/tests/imputation/csai.py +++ b/tests/imputation/csai.py @@ -45,17 +45,15 @@ class TestCSAI(unittest.TestCase): csai = CSAI( n_steps=DATA["n_steps"], n_features=DATA["n_features"], - rnn_hidden_size=32, + rnn_hidden_size=64, imputation_weight=0.7, consistency_weight=0.3, removal_percent=10, # Assume we are removing 10% of the data increase_factor=0.1, compute_intervals=True, step_channels=16, - batch_size=64, epochs=EPOCHS, optimizer=optimizer, - num_workers=0, device=DEVICE, saving_path=saving_path, model_saving_strategy="best",