From 4d44934ebe4f3f24e63a1c59a4a906b503ea4de6 Mon Sep 17 00:00:00 2001
From: Nadav Elyahu
Date: Mon, 11 Mar 2024 18:18:53 +0200
Subject: [PATCH] Allow accelerator to instantiate the device

---
 accelerator/hpu_accelerator.py  | 5 ++---
 deepspeed/runtime/engine.py     | 4 ++--
 deepspeed/runtime/zero/utils.py | 3 +--
 3 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py
index 30b115e8b1ab..3da4d3637dd2 100644
--- a/accelerator/hpu_accelerator.py
+++ b/accelerator/hpu_accelerator.py
@@ -40,9 +40,8 @@ def handles_memory_backpressure(self):
         return True
 
     def device_name(self, device_index=None):
-        if device_index is None:
-            return 'hpu'
-        return 'hpu:{}'.format(device_index)
+        # ignoring device_index.
+        return 'hpu'
 
     def device(self, device_index=None):
         return torch.device(self.device_name(device_index))
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 5c1202ba06ae..24aaac80b0cc 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -977,13 +977,13 @@ def _set_distributed_vars(self, args):
         device_rank = args.device_rank if args is not None and hasattr(args, 'device_rank') else self.local_rank
         if device_rank >= 0:
             get_accelerator().set_device(device_rank)
-            self.device = torch.device(get_accelerator().device_name(), device_rank)
+            self.device = torch.device(get_accelerator().device_name(device_rank))
             self.world_size = dist.get_world_size()
             self.global_rank = dist.get_rank()
         else:
             self.world_size = 1
             self.global_rank = 0
-            self.device = torch.device(get_accelerator().device_name())
+            self.device = get_accelerator().device()
 
     # Configure based on command line arguments
     def _configure_with_arguments(self, args, mpu):
diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py
index f61715bd4387..3993689294c7 100755
--- a/deepspeed/runtime/zero/utils.py
+++ b/deepspeed/runtime/zero/utils.py
@@ -65,8 +65,7 @@ def get_lst_from_rank0(lst: List[int]) -> None:
     lst_tensor = torch.tensor(
         lst if dist.get_rank() == 0 else [-1] * len(lst),
         dtype=int,
-        # device=get_accelerator().current_device_name(),
-        device=torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])),
+        device=get_accelerator().device(os.environ["LOCAL_RANK"]),
         requires_grad=False,
     )
     dist.broadcast(lst_tensor, src=0, async_op=False)
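
A minimal sketch of the accelerator contract these call sites assume
(illustration only, not part of the patch; behavior of the CUDA backend is
stated for comparison and is an assumption here):

    # Sketch: how the touched call sites resolve a device after this patch.
    import torch
    from deepspeed.accelerator import get_accelerator

    acc = get_accelerator()

    # engine.py path: the rank index is now passed to device_name() instead
    # of being given to torch.device() as a second argument, so each backend
    # decides how (or whether) to encode the index. On CUDA, device_name(0)
    # is expected to yield 'cuda:0'; on HPU after this patch, device_name(0)
    # yields plain 'hpu' (index intentionally ignored), with the binding to
    # the local rank done earlier via set_device().
    rank_device = torch.device(acc.device_name(0))

    # Fallback path: the accelerator now constructs the torch.device itself.
    default_device = acc.device()

    print(rank_device, default_device)

Note that the zero/utils.py hunk passes os.environ["LOCAL_RANK"] (a string)
as the index; whether device() accepts a string index is backend-dependent.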