diff --git a/bnpm/torch_helpers.py b/bnpm/torch_helpers.py index dbf44e8..cc2ace6 100644 --- a/bnpm/torch_helpers.py +++ b/bnpm/torch_helpers.py @@ -220,21 +220,19 @@ def set_device( def initialize_torch_settings( - device: Union[str, torch.device] = 'cuda:0', benchmark: Optional[bool] = None, enable_cudnn: Optional[bool] = None, deterministic_cudnn: Optional[bool] = None, deterministic_torch: Optional[bool] = None, - set_global_device: bool = True, + set_global_device: Union[str, torch.device, bool] = False, init_linalg: bool = True, + init_linalg_device: Union[str, torch.device] = 'cuda:0', ) -> None: """ Initalizes some CUDA libraries and sets some environment variables. \n RH 2024 Args: - device (Union[str, torch.device]): - The device to use. benchmark (Optional[bool]): If ``True``, sets torch.backends.cudnn.benchmark to ``True``.\n This results in the built-in cudnn auto-tuner to find the best @@ -250,17 +248,16 @@ def initialize_torch_settings( If ``True``, sets torch.set_deterministic to ``True``.\n This makes torch deterministic. It may slow down operations. set_global_device (bool): - If ``True``, sets the global device to the provided device.\n - This is discouraged in favor of explicit device setting, but useful - for when you want to set the device globally. + If ``False``, does not set the global device. If a string or torch.device, + sets the global device to the specified device. init_linalg (bool): If ``True``, initializes the linalg library. This is necessary to avoid a bug. Often solves the error: "RuntimeError: lazy wrapper - should be called at most once". + should be called at most once". (Default is ``True``) + init_linalg_device (str): + The device to use for initializing the linalg library. Either a + string or a torch.device. (Default is ``'cuda:0'``) """ - if type(device) is str: - device = torch.device(device) - if benchmark is not None: torch.backends.cudnn.benchmark = benchmark if enable_cudnn: @@ -269,14 +266,16 @@ def initialize_torch_settings( torch.backends.cudnn.deterministic = False if deterministic_torch: torch.set_deterministic(False) - if set_global_device: - torch.cuda.set_device(device) + if set_global_device is not False: + torch.cuda.set_device(set_global_device) ## Initialize linalg libarary ## https://github.com/pytorch/pytorch/issues/90613 + if type(init_linalg_device) is str: + init_linalg_device = torch.device(init_linalg_device) if init_linalg: - torch.inverse(torch.ones((1, 1), device=device)) - torch.linalg.qr(torch.as_tensor([[1.0, 2.0], [3.0, 4.0]], device=device)) + torch.inverse(torch.ones((1, 1), device=init_linalg_device)) + torch.linalg.qr(torch.as_tensor([[1.0, 2.0], [3.0, 4.0]], device=init_linalg_device))