From 21214db045b198796b54008764410db00c2aec66 Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Fri, 12 Jul 2024 08:19:13 -0400 Subject: [PATCH] Refactor set_device function to support multiple device types and improve device selection logic --- bnpm/torch_helpers.py | 67 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 58 insertions(+), 9 deletions(-) diff --git a/bnpm/torch_helpers.py b/bnpm/torch_helpers.py index 287ec1e..a54a3a8 100644 --- a/bnpm/torch_helpers.py +++ b/bnpm/torch_helpers.py @@ -199,6 +199,7 @@ def wrapper(*args, **kwargs): def set_device( use_GPU: bool = True, device_num: int = 0, + device_types: List[str] = ['cuda', 'mps', 'xpu', 'cpu'], verbose: bool = True ) -> str: """ @@ -216,6 +217,10 @@ def set_device( (Default is ``True``) device_num (int): Specifies the index of the GPU to use. (Default is ``0``) + device_types (List[str]): + The types and order of devices to attempt to use. The first device + type that is available will be used. Options are ``'cuda'``, + ``'mps'``, ``'xpu'``, and ``'cpu'``. verbose (bool): Determines whether to print the device information. \n * ``True``: the function will print out the device information. @@ -228,20 +233,64 @@ def set_device( A string specifying the device, either *"cpu"* or *"cuda:"*. """ - if use_GPU: - print(f'devices available: {[torch.cuda.get_device_properties(ii) for ii in range(torch.cuda.device_count())]}') if verbose else None - device = torch.device(device_num) if torch.cuda.is_available() else "cpu" - if device == "cpu": - print("no GPU available. Using CPU.") if verbose else None - else: - print(f"Using device: '{device}': {torch.cuda.get_device_properties(device_num)}") if verbose else None + devices = list_available_devices() + + if not use_GPU: + device = 'cpu' else: - device = "cpu" - print(f"device: '{device}'") if verbose else None + device = None + for device_type in device_types: + if len(devices[device_type]) > 0: + device = devices[device_type][device_num] + break + + if verbose: + print(f'Using device: {device}') return device +def list_available_devices() -> dict: + """ + Lists all available PyTorch devices on the system. + RH 2024 + + Returns: + (dict): + A dictionary with device types as keys and lists of available devices as values. + """ + devices = {} + + # Check for CPU devices + if torch.cpu.is_available(): + devices['cpu'] = ['cpu'] + else: + devices['cpu'] = [] + + # Check for CUDA devices + if torch.cuda.is_available(): + devices['cuda'] = [f'cuda:{i}' for i in range(torch.cuda.device_count())] + else: + devices['cuda'] = [] + + # Check for MPS devices + if torch.backends.mps.is_available(): + devices['mps'] = ['mps:0'] + else: + devices['mps'] = [] + + # Check for XPU devices + if hasattr(torch, 'xpu'): + if torch.xpu.is_available(): + devices['xpu'] = [f'xpu:{i}' for i in range(torch.xpu.device_count())] + else: + devices['xpu'] = [] + else: + devices['xpu'] = [] + + return devices + + def initialize_torch_settings( benchmark: Optional[bool] = None, enable_cudnn: Optional[bool] = None,