diff --git a/README.md b/README.md
index e024cc27..fcfda433 100644
--- a/README.md
+++ b/README.md
@@ -126,7 +126,6 @@ The paper references and links are all listed at the bottom of this file.
 | Neural Net | iTransformer🧑‍🔧[^24] | ✅ | | | | | `2024 - ICLR` |
 | Neural Net | ModernTCN[^38] | ✅ | | | | | `2024 - ICLR` |
 | Neural Net | ImputeFormer🧑‍🔧[^34] | ✅ | | | | | `2024 - KDD` |
-| Neural Net | SegRNN[^42] | ✅ | | | | | `2023 - arXiv` |
 | Neural Net | SAITS[^1] | ✅ | | | | | `2023 - ESWA` |
 | Neural Net | FreTS🧑‍🔧[^23] | ✅ | | | | | `2023 - NeurIPS` |
 | Neural Net | Koopa🧑‍🔧[^29] | ✅ | | | | | `2023 - NeurIPS` |
@@ -137,6 +136,8 @@ The paper references and links are all listed at the bottom of this file.
 | Neural Net | MICN🧑‍🔧[^27] | ✅ | | | | | `2023 - ICLR` |
 | Neural Net | DLinear🧑‍🔧[^17] | ✅ | | | | | `2023 - AAAI` |
 | Neural Net | TiDE🧑‍🔧[^28] | ✅ | | | | | `2023 - TMLR` |
+| Neural Net | CSAI[^42] | ✅ | | | | | `2023 - arXiv` |
+| Neural Net | SegRNN🧑‍🔧[^43] | ✅ | | | | | `2023 - arXiv` |
 | Neural Net | SCINet🧑‍🔧[^30] | ✅ | | | | | `2022 - NeurIPS` |
 | Neural Net | Nonstationary Tr.🧑‍🔧[^25] | ✅ | | | | | `2022 - NeurIPS` |
 | Neural Net | FiLM🧑‍🔧[^22] | ✅ | | | | | `2022 - NeurIPS` |
@@ -510,6 +511,9 @@ Time-Series.AI
 [^41]: Xu, Z., Zeng, A., & Xu, Q. (2024).
 [FITS: Modeling Time Series with 10k parameters](https://openreview.net/forum?id=bWcnvZ3qMb).
 *ICLR 2024*.
 
-[^42]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023).
-[Segrnn: Segment recurrent neural network for long-term time series forecasting](https://github.com/lss-1138/SegRNN)
+[^42]: Qian, L., Ibrahim, Z., Ellis, H. L., Zhang, A., Zhang, Y., Wang, T., & Dobson, R. (2023).
+[Knowledge Enhanced Conditional Imputation for Healthcare Time-series](https://arxiv.org/abs/2312.16713).
+*arXiv 2023*.
+[^43]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023).
+[SegRNN: Segment Recurrent Neural Network for Long-Term Time Series Forecasting](https://arxiv.org/abs/2308.11200).
+*arXiv 2023*.
diff --git a/README_zh.md b/README_zh.md
index 55978e01..d85af8ba 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -121,6 +121,8 @@ PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异
 | Neural Net | MICN🧑‍🔧[^27] | ✅ | | | | | `2023 - ICLR` |
 | Neural Net | DLinear🧑‍🔧[^17] | ✅ | | | | | `2023 - AAAI` |
 | Neural Net | TiDE🧑‍🔧[^28] | ✅ | | | | | `2023 - TMLR` |
+| Neural Net | CSAI[^42] | ✅ | | | | | `2023 - arXiv` |
+| Neural Net | SegRNN🧑‍🔧[^43] | ✅ | | | | | `2023 - arXiv` |
 | Neural Net | SCINet🧑‍🔧[^30] | ✅ | | | | | `2022 - NeurIPS` |
 | Neural Net | Nonstationary Tr.🧑‍🔧[^25] | ✅ | | | | | `2022 - NeurIPS` |
 | Neural Net | FiLM🧑‍🔧[^22] | ✅ | | | | | `2022 - NeurIPS` |
@@ -482,3 +484,9 @@ Time-Series.AI
 [^41]: Xu, Z., Zeng, A., & Xu, Q. (2024).
 [FITS: Modeling Time Series with 10k parameters](https://openreview.net/forum?id=bWcnvZ3qMb).
 *ICLR 2024*.
+[^42]: Qian, L., Ibrahim, Z., Ellis, H. L., Zhang, A., Zhang, Y., Wang, T., & Dobson, R. (2023).
+[Knowledge Enhanced Conditional Imputation for Healthcare Time-series](https://arxiv.org/abs/2312.16713).
+*arXiv 2023*.
+[^43]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023).
+[SegRNN: Segment Recurrent Neural Network for Long-Term Time Series Forecasting](https://arxiv.org/abs/2308.11200).
+*arXiv 2023*.
diff --git a/docs/index.rst b/docs/index.rst
index 4aaeb762..fd3b9219 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -165,6 +165,10 @@ The paper references are all listed at the bottom of this readme file.
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
 | Neural Net | TiDE🧑‍🔧 :cite:`das2023tide` | ✅ | | | | | ``2023 - TMLR`` |
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
+| Neural Net | CSAI :cite:`qian2023csai` | ✅ | | | | | ``2023 - arXiv`` |
++----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
+| Neural Net | SegRNN🧑‍🔧 :cite:`lin2023segrnn` | ✅ | | | | | ``2023 - arXiv`` |
++----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
 | Neural Net | SCINet🧑‍🔧 :cite:`liu2022scinet` | ✅ | | | | | ``2022 - NeurIPS`` |
 +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+
 | Neural Net | Nonstationary Tr🧑‍🔧 :cite:`liu2022nonstationary` | ✅ | | | | | ``2022 - NeurIPS`` |
diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst
index a7b47b07..73e55646 100644
--- a/docs/pypots.imputation.rst
+++ b/docs/pypots.imputation.rst
@@ -28,6 +28,24 @@ pypots.imputation.tefn
     :show-inheritance:
     :inherited-members:
 
+pypots.imputation.csai
+------------------------------------
+
+.. automodule:: pypots.imputation.csai
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :inherited-members:
+
+pypots.imputation.segrnn
+------------------------------------
+
+.. automodule:: pypots.imputation.segrnn
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :inherited-members:
+
 pypots.imputation.fits
 ------------------------------------
 
diff --git a/docs/references.bib b/docs/references.bib
index ce0014ea..10a9632c 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -779,3 +779,9 @@ @inproceedings{jin2024timellm
 url={https://openreview.net/forum?id=Unb5CVPtae}
 }
 
+@article{qian2023csai,
+title={Knowledge Enhanced Conditional Imputation for Healthcare Time-series},
+author={Qian, Linglong and Ibrahim, Zina and Ellis, Hugh Logan and Zhang, Ao and Zhang, Yuezhou and Wang, Tao and Dobson, Richard},
+journal={arXiv preprint arXiv:2312.16713},
+year={2023}
+}
\ No newline at end of file
diff --git a/pypots/imputation/csai/model.py b/pypots/imputation/csai/model.py
index b61286c5..010561a2 100644
--- a/pypots/imputation/csai/model.py
+++ b/pypots/imputation/csai/model.py
@@ -21,7 +21,8 @@
 class CSAI(BaseNNImputer):
-    """
+    """The PyTorch implementation of the CSAI model :cite:`qian2023csai`.
+
     Parameters
     ----------
     n_steps :
@@ -58,29 +59,48 @@ class CSAI(BaseNNImputer):
         The number of epochs for training the model.
 
     patience :
-        The patience for the early-stopping mechanism. Given a positive integer, training will stop when no improvement is observed after the specified number of epochs. If set to None, early-stopping is disabled.
+        The patience for the early-stopping mechanism. Given a positive integer, the training process will be
+        stopped when the model does not perform better after that number of epochs.
+        Leaving it default as None will disable the early-stopping.
 
     optimizer :
-        The optimizer used for model training. Defaults to the Adam optimizer if not specified.
+        The optimizer for model training.
+        If not given, will use a default Adam optimizer.
 
     num_workers :
-        The number of subprocesses used for data loading. Setting this to `0` means that data loading is performed in the main process without using subprocesses.
+        The number of subprocesses to use for data loading.
+        `0` means data loading will be in the main process, i.e. there won't be subprocesses.
 
     device :
-        The device for the model to run on, which can be a string, a :class:`torch.device` object, or a list of devices. If not provided, the model will attempt to use available CUDA devices first, then default to CPUs.
+        The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
+        If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
+        then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
+        If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the
+        model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices).
+        Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
 
     saving_path :
-        The path for saving model checkpoints and tensorboard files during training. If not provided, models will not be saved automatically.
+        The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
+        training into a tensorboard file). Will not save if not given.
 
     model_saving_strategy :
-        The strategy for saving model checkpoints. Can be one of [None, "best", "better", "all"]. "best" saves the best model after training, "better" saves any model that improves during training, and "all" saves models after each epoch. If set to None, no models will be saved.
+        The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"].
+        No model will be saved when it is set as None.
+        The "best" strategy will only automatically save the best model after the training finished.
+        The "better" strategy will automatically save the model during training whenever the model performs
+        better than in previous epochs.
+        The "all" strategy will save every model after each epoch training.
 
     verbose :
-        Whether to print training logs during the training process.
+        Whether to print out the training logs during the training process.
 
     Notes
     -----
-    CSAI (Consistent Sequential Imputation) is a bidirectional model designed for time-series imputation. It employs a forward and backward GRU network to handle missing data, using consistency and reconstruction losses to improve accuracy. The model supports various training configurations, such as interval computations, early-stopping, and multiple devices for training. Results can be saved based on the specified saving strategy, and tensorboard files are generated for tracking the model's performance over time.
+    CSAI (Consistent Sequential Imputation) is a bidirectional model designed for time-series imputation.
+    It employs a forward and backward GRU network to handle missing data, using consistency and reconstruction losses
+    to improve accuracy. The model supports various training configurations, such as interval computations,
+    early-stopping, and multiple devices for training. Results can be saved based on the specified saving strategy,
+    and tensorboard files are generated for tracking the model's performance over time.
 
     """
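For reviewers of this PR, the new CSAI imputer is driven through the same PyPOTS workflow as the other models in the tables above. Below is a minimal sketch of the partially-observed dataset format that PyPOTS imputers consume; the `CSAI` constructor arguments in the comments are illustrative placeholders, not the verified signature (see `pypots/imputation/csai/model.py` for the real one):

```python
import numpy as np

# Build a toy partially-observed time-series (POTS) dataset in the dict
# format PyPOTS imputers expect: X has shape (n_samples, n_steps,
# n_features), with NaN marking the missing values to be imputed.
rng = np.random.default_rng(42)
X = rng.standard_normal((32, 24, 5))
X[rng.random(X.shape) < 0.1] = np.nan  # mask ~10% of values as missing
train_set = {"X": X}

# Fitting and imputing then follow the shared BaseNNImputer interface
# that CSAI inherits (shown as comments; hyperparameter names/values
# here are hypothetical, not the model's exact signature):
#
#   from pypots.imputation import CSAI
#   model = CSAI(n_steps=24, n_features=5, ..., epochs=10)
#   model.fit(train_set)
#   X_imputed = model.impute(train_set)  # NaNs replaced by estimates
```

The same `{"X": ...}` dict also feeds `SegRNN` and the other imputers added to `pypots.imputation` in this diff.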