diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index dc465c6..3f388e7 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -815,12 +815,15 @@ def backward( # type: ignore[override] # wrt points start_points = -(torch.tensor(grad_output.shape, device=device) // 2) end_points = start_points + torch.tensor(grad_output.shape, device=device) - coord_ramps = torch.cartesian_prod( - *( - torch.arange(start, end, device=device) - for start, end in zip(start_points, end_points) + coord_ramps = torch.stack( + torch.meshgrid( + *( + torch.arange(start, end, device=device) + for start, end in zip(start_points, end_points) + ), + indexing="ij", ) - ).to(device) + ) # we can't batch in 1d case so we squeeze and fix up the ouput later ramped_grad_output = (