Skip to content

Commit

Permalink
FIX(jit): make meshgrid work when JIT deactivated
Browse files Browse the repository at this point in the history
  • Loading branch information
balbasty committed Feb 28, 2023
1 parent ee3b381 commit 17defac
Showing 1 changed file with 33 additions and 23 deletions.
56 changes: 33 additions & 23 deletions interpol/jit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,37 +383,47 @@ def dot_multi(x, y, dim: List[int], keepdim: bool = False):
return dt


if torch_version('>=', (1, 10)):
@torch.jit.script
def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
return torch.meshgrid(x, indexing='ij')
@torch.jit.script
def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
return torch.meshgrid(x, indexing='xy')
else:
@torch.jit.script
def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
return torch.meshgrid(x)
@torch.jit.script
def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
grid = torch.meshgrid(x)
if len(grid) > 1:
grid[0] = grid[0].transpose(0, 1)
grid[1] = grid[1].transpose(0, 1)
return grid


# cartesian_prod takes multiple inout tensors as input in eager mode
# but takes a list of tensor in jit mode. This is a helper that works
# in both cases.
if not int(os.environ.get('PYTORCH_JIT', '1')):
cartesian_prod = lambda x: torch.cartesian_prod(*x)
meshgrid_ij_list = meshgrid_ij
meshgrid_xy_list = meshgrid_xy
meshgrid_ij = lambda x: meshgrid_ij_list(*x)
meshgrid_xy = lambda x: meshgrid_xy_list(*x)
if torch_version('>=', (1, 10)):
def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
return torch.meshgrid(*x, indexing='ij')
def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
return torch.meshgrid(*x, indexing='xy')
else:
def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
return torch.meshgrid(*x)
def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
grid = torch.meshgrid(*x)
if len(grid) > 1:
grid[0] = grid[0].transpose(0, 1)
grid[1] = grid[1].transpose(0, 1)
return grid

else:
cartesian_prod = torch.cartesian_prod
if torch_version('>=', (1, 10)):
@torch.jit.script
def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
return torch.meshgrid(x, indexing='ij')
@torch.jit.script
def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]:
return torch.meshgrid(x, indexing='xy')
else:
@torch.jit.script
def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]:
return torch.meshgrid(x)
@torch.jit.script
def meshgrid_xyt(x: List[torch.Tensor]) -> List[torch.Tensor]:
grid = torch.meshgrid(x)
if len(grid) > 1:
grid[0] = grid[0].transpose(0, 1)
grid[1] = grid[1].transpose(0, 1)
return grid


meshgrid = meshgrid_ij

0 comments on commit 17defac

Please sign in to comment.