diff --git a/tests/myria3d/test_train_and_predict.py b/tests/myria3d/test_train_and_predict.py index 221b4af0..0ef0c388 100644 --- a/tests/myria3d/test_train_and_predict.py +++ b/tests/myria3d/test_train_and_predict.py @@ -3,6 +3,8 @@ import numpy as np import pytest +from lightning.pytorch.accelerators import find_usable_cuda_devices + from myria3d.pctl.dataset.toy_dataset import TOY_LAS_DATA from myria3d.pctl.dataset.utils import pdal_read_las_array @@ -59,9 +61,7 @@ def test_FrenchLidar_RandLaNetDebug_with_gpu(toy_dataset_hdf5_path, tmpdir_facto tmp_paths_overrides = _make_list_of_necesary_hydra_overrides_with_tmp_paths( toy_dataset_hdf5_path, tmpdir ) - # We will always use the first GPU id for tests, because it always exists if there are some GPUs. - # Attention to concurrency with other processes using the GPU when running tests. - gpu_id = 0 + gpu_id = find_usable_cuda_devices(1) cfg_one_epoch = make_default_hydra_cfg( overrides=[ "experiment=RandLaNetDebug", @@ -216,7 +216,7 @@ def _run_test_right_after_training( tmp_paths_overrides = _make_list_of_necesary_hydra_overrides_with_tmp_paths( toy_dataset_hdf5_path, tmpdir ) - devices = "[0]" if accelerator == "gpu" else 1 + devices = find_usable_cuda_devices(1) if accelerator == "gpu" else 1 cfg_test_using_trained_model = make_default_hydra_cfg( overrides=[ "experiment=test", # sets task.task_name to "test" diff --git a/tests/runif.py b/tests/runif.py index 8f17699e..ec5da504 100644 --- a/tests/runif.py +++ b/tests/runif.py @@ -1,5 +1,6 @@ import pytest import torch +from lightning.pytorch.accelerators import find_usable_cuda_devices """ Simplified from: @@ -35,8 +36,12 @@ def __new__( reasons = [] if min_gpus: - conditions.append(torch.cuda.device_count() < min_gpus) - reasons.append(f"GPUs>={min_gpus}") + try: + find_usable_cuda_devices(min_gpus) + conditions.append(False) + except (ValueError, RuntimeError) as _: + conditions.append(True) + reasons.append(f"GPUs>={min_gpus}") reasons = [rs for cond, rs in zip(conditions, reasons) if cond] return pytest.mark.skipif(