From 3d1520d049b30c95c5c2280699822699b9ec7a66 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Tue, 10 Oct 2023 16:43:04 -0400 Subject: [PATCH] 3d: need meshgrid rather than cartesian_product --- pytorch_finufft/functional.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pytorch_finufft/functional.py b/pytorch_finufft/functional.py index dc465c6..3f388e7 100644 --- a/pytorch_finufft/functional.py +++ b/pytorch_finufft/functional.py @@ -815,12 +815,15 @@ def backward( # type: ignore[override] # wrt points start_points = -(torch.tensor(grad_output.shape, device=device) // 2) end_points = start_points + torch.tensor(grad_output.shape, device=device) - coord_ramps = torch.cartesian_prod( - *( - torch.arange(start, end, device=device) - for start, end in zip(start_points, end_points) + coord_ramps = torch.stack( + torch.meshgrid( + *( + torch.arange(start, end, device=device) + for start, end in zip(start_points, end_points) + ), + indexing="ij", ) - ).to(device) + ) # we can't batch in 1d case so we squeeze and fix up the ouput later ramped_grad_output = (