Skip to content

Commit

Permalink
refactor: update CSAI default arguments;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Oct 27, 2024
1 parent edd144d commit 81ea218
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 12 deletions.
6 changes: 3 additions & 3 deletions pypots/classification/csai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ def __init__(
increase_factor: float,
compute_intervals: bool,
step_channels: int,
batch_size: int,
epochs: int,
dropout: float = 0.5,
patience: Union[int, None] = None,
batch_size: int = 32,
epochs: int = 100,
patience: Optional[int] = None,
optimizer: Optimizer = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
Expand Down
6 changes: 3 additions & 3 deletions pypots/imputation/csai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def __init__(
increase_factor: float,
compute_intervals: bool,
step_channels: int,
batch_size: int,
epochs: int,
patience: Union[int, None] = None,
batch_size: int = 32,
epochs: int = 100,
patience: Optional[int] = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Union[str, torch.device, list, None] = None,
Expand Down
4 changes: 1 addition & 3 deletions tests/classification/csai.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,17 @@ class TestCSAI(unittest.TestCase):
n_steps=DATA["n_steps"],
n_features=DATA["n_features"],
n_classes=DATA["n_classes"],
rnn_hidden_size=32,
rnn_hidden_size=64,
imputation_weight=0.7,
consistency_weight=0.3,
classification_weight=1.0,
removal_percent=10,
increase_factor=0.1,
compute_intervals=True,
step_channels=16,
batch_size=64,
epochs=EPOCHS,
dropout=0.5,
optimizer=optimizer,
num_workers=4,
device=DEVICE,
saving_path=saving_path,
model_saving_strategy="better",
Expand Down
4 changes: 1 addition & 3 deletions tests/imputation/csai.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,15 @@ class TestCSAI(unittest.TestCase):
csai = CSAI(
n_steps=DATA["n_steps"],
n_features=DATA["n_features"],
rnn_hidden_size=32,
rnn_hidden_size=64,
imputation_weight=0.7,
consistency_weight=0.3,
removal_percent=10, # Assume we are removing 10% of the data
increase_factor=0.1,
compute_intervals=True,
step_channels=16,
batch_size=64,
epochs=EPOCHS,
optimizer=optimizer,
num_workers=0,
device=DEVICE,
saving_path=saving_path,
model_saving_strategy="best",
Expand Down

0 comments on commit 81ea218

Please sign in to comment.