Skip to content

Commit

Permalink
refactor initialize_torch_settings
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Apr 6, 2024
1 parent daf4adb commit c35e261
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions bnpm/torch_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,21 +220,19 @@ def set_device(


def initialize_torch_settings(
device: Union[str, torch.device] = 'cuda:0',
benchmark: Optional[bool] = None,
enable_cudnn: Optional[bool] = None,
deterministic_cudnn: Optional[bool] = None,
deterministic_torch: Optional[bool] = None,
set_global_device: bool = True,
set_global_device: Union[str, torch.device, bool] = False,
init_linalg: bool = True,
init_linalg_device: Union[str, torch.device] = 'cuda:0',
) -> None:
"""
    Initializes some CUDA libraries and sets some environment variables. \n
RH 2024
Args:
device (Union[str, torch.device]):
The device to use.
benchmark (Optional[bool]):
If ``True``, sets torch.backends.cudnn.benchmark to ``True``.\n
This results in the built-in cudnn auto-tuner to find the best
Expand All @@ -250,17 +248,16 @@ def initialize_torch_settings(
If ``True``, sets torch.set_deterministic to ``True``.\n
This makes torch deterministic. It may slow down operations.
set_global_device (bool):
If ``True``, sets the global device to the provided device.\n
This is discouraged in favor of explicit device setting, but useful
for when you want to set the device globally.
If ``False``, does not set the global device. If a string or torch.device,
sets the global device to the specified device.
init_linalg (bool):
If ``True``, initializes the linalg library. This is necessary to
avoid a bug. Often solves the error: "RuntimeError: lazy wrapper
should be called at most once".
should be called at most once". (Default is ``True``)
init_linalg_device (str):
The device to use for initializing the linalg library. Either a
string or a torch.device. (Default is ``'cuda:0'``)
"""
if type(device) is str:
device = torch.device(device)

if benchmark is not None:
torch.backends.cudnn.benchmark = benchmark
if enable_cudnn:
Expand All @@ -269,14 +266,16 @@ def initialize_torch_settings(
torch.backends.cudnn.deterministic = False
if deterministic_torch:
torch.set_deterministic(False)
if set_global_device:
torch.cuda.set_device(device)
if set_global_device is not False:
torch.cuda.set_device(set_global_device)

    ## Initialize linalg library
## https://github.com/pytorch/pytorch/issues/90613
if type(init_linalg_device) is str:
init_linalg_device = torch.device(init_linalg_device)
if init_linalg:
torch.inverse(torch.ones((1, 1), device=device))
torch.linalg.qr(torch.as_tensor([[1.0, 2.0], [3.0, 4.0]], device=device))
torch.inverse(torch.ones((1, 1), device=init_linalg_device))
torch.linalg.qr(torch.as_tensor([[1.0, 2.0], [3.0, 4.0]], device=init_linalg_device))



Expand Down

0 comments on commit c35e261

Please sign in to comment.