Skip to content

Commit

Permalink
RF PT get_device (Tensor.device), fix with dev index
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 28, 2023
1 parent d96c93b commit 200b7a4
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions returnn/torch/frontend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,13 @@ def get_new_dim_raw(raw_tensor: torch.Tensor, axis: int, *, name: str) -> Dim:
@staticmethod
def get_device(x: Tensor[torch.Tensor]) -> Optional[str]:
"""device"""
if x.raw_tensor is None:
raw_tensor: torch.Tensor = x.raw_tensor
if raw_tensor is None:
return None
return x.raw_tensor.device.type
dev = raw_tensor.device
if dev.index is None:
return dev.type
return f"{dev.type}:{dev.index}"

@staticmethod
def copy_to_device(x: Tensor, device: Optional[str]) -> Tensor:
Expand Down

0 comments on commit 200b7a4

Please sign in to comment.