Skip to content

Commit

Permalink
3d: need meshgrid rather than cartesian_product
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Oct 10, 2023
1 parent cea3856 commit 3d1520d
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,12 +815,15 @@ def backward( # type: ignore[override]
# wrt points
start_points = -(torch.tensor(grad_output.shape, device=device) // 2)
end_points = start_points + torch.tensor(grad_output.shape, device=device)
coord_ramps = torch.cartesian_prod(
*(
torch.arange(start, end, device=device)
for start, end in zip(start_points, end_points)
coord_ramps = torch.stack(
torch.meshgrid(
*(
torch.arange(start, end, device=device)
for start, end in zip(start_points, end_points)
),
indexing="ij",
)
).to(device)
)

# we can't batch in 1d case so we squeeze and fix up the ouput later
ramped_grad_output = (
Expand Down

0 comments on commit 3d1520d

Please sign in to comment.