From 7f2a0aa3b7b5c109eff6f8049e70eff74220d52c Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Sat, 6 Apr 2024 17:05:28 -0400 Subject: [PATCH] Update set_global_device parameter in initialize_torch_settings function --- bnpm/torch_helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bnpm/torch_helpers.py b/bnpm/torch_helpers.py index cc2ace6..4ee219b 100644 --- a/bnpm/torch_helpers.py +++ b/bnpm/torch_helpers.py @@ -224,7 +224,7 @@ def initialize_torch_settings( enable_cudnn: Optional[bool] = None, deterministic_cudnn: Optional[bool] = None, deterministic_torch: Optional[bool] = None, - set_global_device: Union[str, torch.device, bool] = False, + set_global_device: Optional[str, torch.device] = None, init_linalg: bool = True, init_linalg_device: Union[str, torch.device] = 'cuda:0', ) -> None: @@ -266,7 +266,7 @@ def initialize_torch_settings( torch.backends.cudnn.deterministic = False if deterministic_torch: torch.set_deterministic(False) - if set_global_device is not False: + if set_global_device is not None: torch.cuda.set_device(set_global_device) ## Initialize linalg libarary