Skip to content

Commit

Permalink
Make mlc.torch.tonp() gracefully handle non-Tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
mxbi authored Dec 31, 2018
1 parent b8f6f3e commit b8a2dda
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion mlcrate/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@ def _check_torch_import():

def tonp(tensor):
"""Takes any PyTorch tensor and converts it to a numpy array or scalar as appropiate.
When given something that isn't a PyTorch tensor, it will attempt to convert to a NumPy array or scalar anyway.
Not heavily optimized."""
arr = tensor.data.detach().cpu().numpy()
_check_torch_import()
if isinstance(tensor, torch.Tensor):
arr = tensor.data.detach().cpu().numpy()
else: # It's not a tensor! We'll handle it anyway
arr = np.array(tensor)
if arr.shape == ():
return np.asscalar(arr)
else:
Expand Down

0 comments on commit b8a2dda

Please sign in to comment.