Skip to content

Commit

Permalink
Refactor set_device function to support multiple device types and imp…
Browse files Browse the repository at this point in the history
…rove device selection logic
  • Loading branch information
RichieHakim committed Jul 12, 2024
1 parent e15b7fb commit 21214db
Showing 1 changed file with 58 additions and 9 deletions.
67 changes: 58 additions & 9 deletions bnpm/torch_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def wrapper(*args, **kwargs):
def set_device(
use_GPU: bool = True,
device_num: int = 0,
device_types: List[str] = ['cuda', 'mps', 'xpu', 'cpu'],
verbose: bool = True
) -> str:
"""
Expand All @@ -216,6 +217,10 @@ def set_device(
(Default is ``True``)
device_num (int):
Specifies the index of the GPU to use. (Default is ``0``)
device_types (List[str]):
The types and order of devices to attempt to use. The first device
type that is available will be used. Options are ``'cuda'``,
``'mps'``, ``'xpu'``, and ``'cpu'``.
verbose (bool):
Determines whether to print the device information. \n
* ``True``: the function will print out the device information.
Expand All @@ -228,20 +233,64 @@ def set_device(
A string specifying the device, either *"cpu"* or
*"cuda:<device_num>"*.
"""
if use_GPU:
print(f'devices available: {[torch.cuda.get_device_properties(ii) for ii in range(torch.cuda.device_count())]}') if verbose else None
device = torch.device(device_num) if torch.cuda.is_available() else "cpu"
if device == "cpu":
print("no GPU available. Using CPU.") if verbose else None
else:
print(f"Using device: '{device}': {torch.cuda.get_device_properties(device_num)}") if verbose else None
devices = list_available_devices()

if not use_GPU:
device = 'cpu'
else:
device = "cpu"
print(f"device: '{device}'") if verbose else None
device = None
for device_type in device_types:
if len(devices[device_type]) > 0:
device = devices[device_type][device_num]
break

if verbose:
print(f'Using device: {device}')

return device


def list_available_devices() -> dict:
"""
Lists all available PyTorch devices on the system.
RH 2024
Returns:
(dict):
A dictionary with device types as keys and lists of available devices as values.
"""
devices = {}

# Check for CPU devices
if torch.cpu.is_available():
devices['cpu'] = ['cpu']
else:
devices['cpu'] = []

# Check for CUDA devices
if torch.cuda.is_available():
devices['cuda'] = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
else:
devices['cuda'] = []

# Check for MPS devices
if torch.backends.mps.is_available():
devices['mps'] = ['mps:0']
else:
devices['mps'] = []

# Check for XPU devices
if hasattr(torch, 'xpu'):
if torch.xpu.is_available():
devices['xpu'] = [f'xpu:{i}' for i in range(torch.xpu.device_count())]
else:
devices['xpu'] = []
else:
devices['xpu'] = []

return devices


def initialize_torch_settings(
benchmark: Optional[bool] = None,
enable_cudnn: Optional[bool] = None,
Expand Down

0 comments on commit 21214db

Please sign in to comment.