From 18945e8fb6241e8ac9436e82fae1e951bfc28c11 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 21 Sep 2023 21:32:57 +0800 Subject: [PATCH] fix: turn missing_mask into torch.float; --- pypots/data/base.py | 4 ++-- pypots/imputation/gpvae/data.py | 2 +- tests/imputation/gpvae.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pypots/data/base.py b/pypots/data/base.py index a71d4014..1bef9f9c 100644 --- a/pypots/data/base.py +++ b/pypots/data/base.py @@ -205,7 +205,7 @@ def _fetch_data_from_array(self, idx: int) -> Iterable: """ X = self.X[idx].to(torch.float32) - missing_mask = ~torch.isnan(X) + missing_mask = (~torch.isnan(X)).to(torch.float32) X = torch.nan_to_num(X) sample = [ torch.tensor(idx), @@ -280,7 +280,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: self.file_handle = self._open_file_handle() X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32) - missing_mask = ~torch.isnan(X) + missing_mask = (~torch.isnan(X)).to(torch.float32) X = torch.nan_to_num(X) sample = [ torch.tensor(idx), diff --git a/pypots/imputation/gpvae/data.py b/pypots/imputation/gpvae/data.py index b7cd6fc5..8bb9be8c 100644 --- a/pypots/imputation/gpvae/data.py +++ b/pypots/imputation/gpvae/data.py @@ -116,7 +116,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: self.file_handle = self._open_file_handle() X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32) - missing_mask = ~torch.isnan(X) + missing_mask = (~torch.isnan(X)).to(torch.float32) X = torch.nan_to_num(X) sample = [ diff --git a/tests/imputation/gpvae.py b/tests/imputation/gpvae.py index 07dfeab2..9c59c5b2 100644 --- a/tests/imputation/gpvae.py +++ b/tests/imputation/gpvae.py @@ -80,7 +80,7 @@ def test_2_parameters(self): and self.gp_vae.best_model_dict is not None ) - @pytest.mark.xdist_group(name="imputation-GPVAE") + @pytest.mark.xdist_group(name="imputation-gpvae") def test_3_saving_path(self): # whether the root saving dir exists, which should be created by save_log_into_tb_file assert os.path.exists(