diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 0daf167e14e0..3ecbe9738f48 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -30,12 +30,13 @@ def __init__(self): def _init_pynvml(self): global pynvml try: - import pynvml + import pynvml as tmp_pynvml except ImportError: return try: - pynvml.nvmlInit() - except pynvml.NVMLError: + tmp_pynvml.nvmlInit() + pynvml = tmp_pynvml + except tmp_pynvml.NVMLError: return def is_synchronized_device(self):