diff --git a/pypots/imputation/tefn/model.py b/pypots/imputation/tefn/model.py
index e97b6f4e..c7912164 100644
--- a/pypots/imputation/tefn/model.py
+++ b/pypots/imputation/tefn/model.py
@@ -39,7 +39,8 @@ class TEFN(BaseNNImputer):
     apply_nonstationary_norm :
         Whether to apply non-stationary normalization to the input data for TimesNet.
         Please refer to :cite:`liu2022nonstationary` for details about non-stationary normalization,
-        which is not the idea of the original TimesNet paper. Hence, we make it optional and default not to use here.
+        which is not the idea of the original TimesNet paper. Hence, we make it optional
+        and default not to use here.
 
     batch_size :
         The batch size for training and evaluating the model.
@@ -64,9 +65,9 @@ class TEFN(BaseNNImputer):
         The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
         If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
         then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
-        If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the
-        model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices).
-        Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
+        If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')],
+        the model will be trained in parallel on the multiple devices (so far parallel training is only supported
+        on CUDA devices). Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
 
     saving_path :
         The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during