From dbcc3c545ee01b3dae11a1a02656aac2b36c5a7d Mon Sep 17 00:00:00 2001 From: Luigi Bonati Date: Tue, 11 Jun 2024 22:42:44 +0200 Subject: [PATCH] fixed device in tests --- mlcolvar/tests/test_cvs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlcolvar/tests/test_cvs.py b/mlcolvar/tests/test_cvs.py index a76c1a99..20d09c93 100644 --- a/mlcolvar/tests/test_cvs.py +++ b/mlcolvar/tests/test_cvs.py @@ -74,6 +74,7 @@ def dataset(): def test_resume_from_checkpoint(cv_model, dataset): """CVs correctly resume from a checkpoint.""" datamodule = DictModule(dataset, lengths=[1.0,0.], batch_size=len(dataset)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Run a few steps of training in a temporary directory. with tempfile.TemporaryDirectory() as tmp_dir_path: @@ -94,7 +95,7 @@ def test_resume_from_checkpoint(cv_model, dataset): cv_model2 = cv_model.__class__.load_from_checkpoint(checkpoint_file_path) # Check that state is the same. - x = dataset['data'] - cv_model.eval() - cv_model2.eval() + x = dataset['data'].to(device) + cv_model.to(device).eval() + cv_model2.to(device).eval() assert torch.allclose(cv_model(x), cv_model2(x))