Skip to content

Commit

Permalink
Fix DST bounds across torch versions (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
balbasty authored Apr 18, 2023
1 parent 4f3edfa commit 414ed52
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
7 changes: 5 additions & 2 deletions interpol/bounds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from enum import Enum
from typing import Optional
from .jit_utils import floor_div
Tensor = torch.Tensor


Expand Down Expand Up @@ -69,12 +70,14 @@ def transform(self, i, n: int) -> Optional[Tensor]:
i = i.remainder(n2)
x = torch.where(i == 0, zero, one)
x = torch.where(i.remainder(n + 1) == n, zero, x)
x = torch.where(i.floor_divide(n+1).remainder(2) > 0, -x, x)
i = floor_div(i, n+1)
x = torch.where(torch.remainder(i, 2) > 0, -x, x)
return x
elif self.type == 5: # dst2
i = torch.where(i < 0, n - 1 - i, i)
x = torch.ones([1], dtype=torch.int8, device=i.device)
x = torch.where(i.floor_divide(n).remainder(2) > 0, -x, x)
i = floor_div(i, n)
x = torch.where(torch.remainder(i, 2) > 0, -x, x)
return x
elif self.type == 0: # zero
one = torch.ones([1], dtype=torch.int8, device=i.device)
Expand Down
14 changes: 14 additions & 0 deletions interpol/jit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,17 @@ def meshgrid_xyt(x: List[torch.Tensor]) -> List[torch.Tensor]:


meshgrid = meshgrid_ij


# In torch < 1.6, div applied to integer tensor performed a floor_divide
# In torch > 1.6, it performs a true divide.
# Floor division must be done using `floor_divide`, but it was buggy
# until torch 1.13 (it was doing a trunc divide instead of a floor divide).
# There was at some point a deprecation warning for floor_divide, but it
# seems to have been lifted afterwards. In torch >= 1.13, floor_divide
# performs a correct floor division.
# Since we only apply floor_divide ot positive values, we are fine.
if torch_version('<', (1, 6)):
floor_div = torch.div
else:
floor_div = torch.floor_divide

0 comments on commit 414ed52

Please sign in to comment.