Skip to content

Commit

Permalink
refactor: make patience optional, init KMeans in CRLI properly, and m…
Browse files Browse the repository at this point in the history
…ake VaDER raise error when training failed;
  • Loading branch information
WenjieDu committed Sep 23, 2023
1 parent f447b27 commit 8475c96
Show file tree
Hide file tree
Showing 14 changed files with 17 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def __init__(
self,
batch_size: int,
epochs: int,
patience: int,
patience: Optional[int] = None,
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
Expand Down
2 changes: 1 addition & 1 deletion pypots/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(
n_classes: int,
batch_size: int,
epochs: int,
patience: int,
patience: Optional[int] = None,
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
Expand Down
2 changes: 1 addition & 1 deletion pypots/classification/brits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(
reconstruction_weight: float = 1,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
patience: Optional[int] = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
Expand Down
2 changes: 1 addition & 1 deletion pypots/classification/grud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def __init__(
rnn_hidden_size: int,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
patience: Optional[int] = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
Expand Down
2 changes: 1 addition & 1 deletion pypots/classification/raindrop/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def __init__(
static=False,
batch_size=32,
epochs=100,
patience: int = None,
patience: Optional[int] = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
Expand Down
2 changes: 1 addition & 1 deletion pypots/clustering/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __init__(
n_clusters: int,
batch_size: int,
epochs: int,
patience: int,
patience: Optional[int] = None,
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
Expand Down
4 changes: 3 additions & 1 deletion pypots/clustering/crli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def __init__(
n_steps, rnn_hidden_size * 2, n_features, decoder_fcn_output_dims, device
) # fully connected network is included in Decoder
self.kmeans = KMeans(
n_clusters=n_clusters
n_clusters=n_clusters,
n_init=10, # FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the
# value of `n_init` explicitly to suppress the warning.
) # TODO: implement KMean with torch for gpu acceleration

self.n_clusters = n_clusters
Expand Down
4 changes: 2 additions & 2 deletions pypots/clustering/vader/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def __init__(
batch_size: int = 32,
epochs: int = 100,
pretrain_epochs: int = 10,
patience: int = None,
patience: Optional[int] = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
Expand Down Expand Up @@ -462,7 +462,7 @@ def _train_model(
"Now quit to let you check your model training.\n"
"Please raise an issue https://github.com/WenjieDu/PyPOTS/issues if you have questions."
)
exit()
raise RuntimeError
else:
reg_covar *= 2

Expand Down
2 changes: 1 addition & 1 deletion pypots/forecasting/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def __init__(
# n_forecasting_steps: int,
batch_size: int,
epochs: int,
patience: int,
patience: Optional[int] = None,
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __init__(
self,
batch_size: int,
epochs: int,
patience: int,
patience: Optional[int] = None,
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/brits/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def __init__(
rnn_hidden_size: int,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
patience: Optional[int] = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/gpvae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def __init__(
window_size: int = 3,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
patience: Optional[int] = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/mrnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __init__(
rnn_hidden_size: int,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
patience: Optional[int] = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/transformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def __init__(
MIT_weight: int = 1,
batch_size: int = 32,
epochs: int = 100,
patience: int = None,
patience: Optional[int] = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
Expand Down

0 comments on commit 8475c96

Please sign in to comment.