Skip to content

Commit

Permalink
fix(convert): torch.from_numpy requires positive strides
Browse files Browse the repository at this point in the history
  • Loading branch information
nkemnitz committed Sep 6, 2023
1 parent 59a3109 commit c616460
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions zetta_utils/tensor_ops/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def to_torch(data: Tensor, device: torch.types.Device = None) -> torch.Tensor:
if data.max() > np.uint64(2 ** 63 - 1):
raise ValueError("Unable to convert uint64 dtype to int64")
data = data.astype(np.int64)
if any(v < 0 for v in data.strides): # torch.from_numpy does not support negative strides
data = data.copy("K")
result = torch.from_numpy(data).to(device) # type: ignore # pytorch bug

return result
Expand Down

0 comments on commit c616460

Please sign in to comment.