Skip to content

Commit

Permalink
Allow accelerator to instantiate the device
Browse files Browse the repository at this point in the history
  • Loading branch information
nelyahu committed Mar 11, 2024
1 parent 535a908 commit 4d44934
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
5 changes: 2 additions & 3 deletions accelerator/hpu_accelerator.py
Original file line number Diff line number Diff line change
def handles_memory_backpressure(self):
    """Report that the HPU accelerator manages memory backpressure itself.

    Returning True tells the runtime it does not need to apply its own
    memory-backpressure handling for this accelerator.
    """
    return True

def device_name(self, device_index=None):
    """Return the canonical device string for the HPU accelerator.

    Args:
        device_index: accepted for interface compatibility with other
            accelerators but deliberately ignored — the HPU runtime
            instantiates/resolves the concrete device itself, so callers
            always receive the bare 'hpu' string.

    Returns:
        The string 'hpu', regardless of device_index.
    """
    # ignoring device_index.
    return 'hpu'

def device(self, device_index=None):
    """Build a torch.device from this accelerator's device name.

    The optional index is forwarded to device_name(), which decides
    whether (and how) to incorporate it into the device string.
    """
    name = self.device_name(device_index)
    return torch.device(name)
Expand Down
4 changes: 2 additions & 2 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
def _set_distributed_vars(self, args):
    """Initialize self.device, self.world_size and self.global_rank.

    Uses args.device_rank when supplied, otherwise falls back to
    self.local_rank. A negative rank means distributed mode is not
    active: single-process defaults are used and the accelerator
    instantiates its own default device.
    """
    device_rank = args.device_rank if args is not None and hasattr(args, 'device_rank') else self.local_rank
    if device_rank >= 0:
        get_accelerator().set_device(device_rank)
        # Let the accelerator compose the full device string itself
        # (index handling is accelerator-specific, e.g. HPU ignores it)
        # rather than building torch.device(name, index) here.
        self.device = torch.device(get_accelerator().device_name(device_rank))
        self.world_size = dist.get_world_size()
        self.global_rank = dist.get_rank()
    else:
        self.world_size = 1
        self.global_rank = 0
        # Non-distributed: the accelerator instantiates the device directly.
        self.device = get_accelerator().device()

# Configure based on command line arguments
def _configure_with_arguments(self, args, mpu):
Expand Down
3 changes: 1 addition & 2 deletions deepspeed/runtime/zero/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ def get_lst_from_rank0(lst: List[int]) -> None:
lst_tensor = torch.tensor(
lst if dist.get_rank() == 0 else [-1] * len(lst),
dtype=int,
# device=get_accelerator().current_device_name(),
device=torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])),
device=get_accelerator().device(os.environ["LOCAL_RANK"]),
requires_grad=False,
)
dist.broadcast(lst_tensor, src=0, async_op=False)
Expand Down

0 comments on commit 4d44934

Please sign in to comment.