From 17defac9a7f17a1b99ada769de22f41a81a7e7d7 Mon Sep 17 00:00:00 2001 From: Yael Balbastre Date: Tue, 28 Feb 2023 17:38:37 -0500 Subject: [PATCH] FIX(jit): make meshgrid work when JIT deactivated --- interpol/jit_utils.py | 56 +++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/interpol/jit_utils.py b/interpol/jit_utils.py index b17be50..8bad7ad 100644 --- a/interpol/jit_utils.py +++ b/interpol/jit_utils.py @@ -383,37 +383,47 @@ def dot_multi(x, y, dim: List[int], keepdim: bool = False): return dt -if torch_version('>=', (1, 10)): - @torch.jit.script - def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]: - return torch.meshgrid(x, indexing='ij') - @torch.jit.script - def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]: - return torch.meshgrid(x, indexing='xy') -else: - @torch.jit.script - def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]: - return torch.meshgrid(x) - @torch.jit.script - def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]: - grid = torch.meshgrid(x) - if len(grid) > 1: - grid[0] = grid[0].transpose(0, 1) - grid[1] = grid[1].transpose(0, 1) - return grid - # cartesian_prod takes multiple inout tensors as input in eager mode # but takes a list of tensor in jit mode. This is a helper that works # in both cases. if not int(os.environ.get('PYTORCH_JIT', '1')): cartesian_prod = lambda x: torch.cartesian_prod(*x) - meshgrid_ij_list = meshgrid_ij - meshgrid_xy_list = meshgrid_xy - meshgrid_ij = lambda x: meshgrid_ij_list(*x) - meshgrid_xy = lambda x: meshgrid_xy_list(*x) + if torch_version('>=', (1, 10)): + def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]: + return torch.meshgrid(*x, indexing='ij') + def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]: + return torch.meshgrid(*x, indexing='xy') + else: + def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]: + return torch.meshgrid(*x) + def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]: + grid = torch.meshgrid(*x) + if len(grid) > 1: + grid[0] = grid[0].transpose(0, 1) + grid[1] = grid[1].transpose(0, 1) + return grid + else: cartesian_prod = torch.cartesian_prod + if torch_version('>=', (1, 10)): + @torch.jit.script + def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]: + return torch.meshgrid(x, indexing='ij') + @torch.jit.script + def meshgrid_xy(x: List[torch.Tensor]) -> List[torch.Tensor]: + return torch.meshgrid(x, indexing='xy') + else: + @torch.jit.script + def meshgrid_ij(x: List[torch.Tensor]) -> List[torch.Tensor]: + return torch.meshgrid(x) + @torch.jit.script + def meshgrid_xyt(x: List[torch.Tensor]) -> List[torch.Tensor]: + grid = torch.meshgrid(x) + if len(grid) > 1: + grid[0] = grid[0].transpose(0, 1) + grid[1] = grid[1].transpose(0, 1) + return grid meshgrid = meshgrid_ij