Skip to content

Commit

Permalink
WIP bring forward and backward into one class
Browse files Browse the repository at this point in the history
  • Loading branch information
eickenberg committed Oct 5, 2023
1 parent 937d7ef commit 48410ff
Showing 1 changed file with 107 additions and 110 deletions.
217 changes: 107 additions & 110 deletions pytorch_finufft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,142 +1601,139 @@ def backward(
# Consolidated forward function for all 1D, 2D, and 3D problems for nufft type 1
###############################################################################

# This function takes a ctx object and is supposed to later replace all type 1
# functions above for all dimensionalities.

def finufft_type1(
ctx: Any,
points: torch.Tensor,
values: torch.Tensor,
output_shape: Union[int, tuple[int, int], tuple[int, int, int]],
out: Optional[torch.Tensor]=None,
fftshift: bool=False,
finufftkwargs: dict[str, Union[int, float]]=None):
"""
Evaluates the Type 1 NUFFT on the inputs.
class finufft_type1(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
points: torch.Tensor,
values: torch.Tensor,
output_shape: Union[int, tuple[int, int], tuple[int, int, int]],
out: Optional[torch.Tensor]=None,
fftshift: bool=False,
finufftkwargs: dict[str, Union[int, float]]=None):
"""
Evaluates the Type 1 NUFFT on the inputs.
"""
"""

if out is not None:
print("In-place results are not yet implemented")
# All this requires is a check on the out array to make sure it is the
# correct shape.
if out is not None:
print("In-place results are not yet implemented")
# All this requires is a check on the out array to make sure it is the
# correct shape.

err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately
err._type1_checks(*points.T, values, output_shape) # revisit these error checks to take into account the shape of points instead of passing them separately


if finufftkwargs is None:
finufftkwargs = dict()
finufftkwargs = {k: v for k, v in finufftkwargs.items()}
_mode_ordering = finufftkwargs.pop("modeord", 1)
_i_sign = finufftkwargs.pop("isign", -1)
if finufftkwargs is None:
finufftkwargs = dict()
finufftkwargs = {k: v for k, v in finufftkwargs.items()}
_mode_ordering = finufftkwargs.pop("modeord", 1)
_i_sign = finufftkwargs.pop("isign", -1)

if fftshift:
# TODO -- this check should be done elsewhere? or error msg changed
# to note instead that there is a conflict in fftshift
if _mode_ordering != 1:
raise ValueError(
"Double specification of ordering; only one of fftshift and modeord should be provided"
)
_mode_ordering = 0
if fftshift:
# TODO -- this check should be done elsewhere? or error msg changed
# to note instead that there is a conflict in fftshift
if _mode_ordering != 1:
raise ValueError(
"Double specification of ordering; only one of fftshift and modeord should be provided"
)
_mode_ordering = 0

ctx.save_for_backward(*points.T, values)
ctx.save_for_backward(*points.T, values)

ctx.isign = _i_sign
ctx.mode_ordering = _mode_ordering
ctx.finufftkwargs = finufftkwargs
ctx.isign = _i_sign
ctx.mode_ordering = _mode_ordering
ctx.finufftkwargs = finufftkwargs

finufft_out = torch.from_numpy(
finufft.nufft3d1(
*points.data.T.numpy(),
values.data.numpy(),
output_shape,
modeord=_mode_ordering,
isign=_i_sign,
**finufftkwargs,
finufft_out = torch.from_numpy(
finufft.nufft3d1(
*points.data.T.numpy(),
values.data.numpy(),
output_shape,
modeord=_mode_ordering,
isign=_i_sign,
**finufftkwargs,
)
)
)

return finufft_out
return finufft_out

@staticmethod
def backward(
ctx: Any, grad_output: torch.Tensor
) -> tuple[Union[torch.Tensor, None], ...]:
"""
Implements derivatives wrt. each argument in the forward method.
Parameters
----------
ctx : Any
Pytorch context object.
grad_output : torch.Tensor
Backpass gradient output
Returns
-------
tuple[Union[torch.Tensor, None], ...]
A tuple of derivatives wrt. each argument in the forward method
"""
_i_sign = -1 * ctx.isign
_mode_ordering = ctx.mode_ordering
finufftkwargs = ctx.finufftkwargs

def backward_type1(
ctx: Any, grad_output: torch.Tensor
) -> tuple[Union[torch.Tensor, None], ...]:
"""
Implements derivatives wrt. each argument in the forward method.
Parameters
----------
ctx : Any
Pytorch context object.
grad_output : torch.Tensor
Backpass gradient output
Returns
-------
tuple[Union[torch.Tensor, None], ...]
A tuple of derivatives wrt. each argument in the forward method
"""
_i_sign = -1 * ctx.isign
_mode_ordering = ctx.mode_ordering
finufftkwargs = ctx.finufftkwargs

points, values = ctx.saved_tensors

points, values = ctx.saved_tensors
start_points = -np.array(grad_output.shape) // 2
end_points = start_points + grad_output.shape
slices = tuple(slice(start, end) for start, end in zip(start_points, end_points))

start_points = -np.array(grad_output.shape) // 2
end_points = start_points + grad_output.shape
slices = tuple(slice(start, end) for start, end in zip(start_points, end_points))
coord_ramps = torch.mgrid[slices]

coord_ramps = torch.mgrid[slices]
grads_points = None
grad_values = None

grads_points = None
grad_values = None
if ctx.needs_input_grad[0]:
# wrt points

if ctx.needs_input_grad[0]:
# wrt points
if _mode_ordering != 0:
coord_ramps = torch.fft.ifftshift(coord_ramps)

ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign

grads_points = []
for ramp in ramped_grad_output: # we can batch this into finufft
backprop_ramp = torch.from_numpy(
finufft.nufft3d2(
*points.T.numpy(),
ramp.data.numpy(),
isign=_i_sign,
modeord=_mode_ordering,
**finufftkwargs,
))
grad_points = (backprop_ramp.conj() * values).real
grads_points.append(grad_points)

grads_points = torch.stack(grads_points)

if _mode_ordering != 0:
coord_ramps = torch.fft.ifftshift(coord_ramps)

ramped_grad_output = coord_ramps * grad_output[np.newaxis] * 1j * _i_sign
if ctx.needs_input_grad[1]:
np_grad_output = grad_output.data.numpy()

grads_points = []
for ramp in ramped_grad_output: # we can batch this into finufft
backprop_ramp = torch.from_numpy(
grad_values = torch.from_numpy(
finufft.nufft3d2(
*points.T.numpy(),
ramp.data.numpy(),
*points.T.numpy()
np_grad_output,
isign=_i_sign,
modeord=_mode_ordering,
**finufftkwargs,
))
grad_points = (backprop_ramp.conj() * values).real
grads_points.append(grad_points)

grads_points = torch.stack(grads_points)

if ctx.needs_input_grad[1]:
np_grad_output = grad_output.data.numpy()

grad_values = torch.from_numpy(
finufft.nufft3d2(
*points.T.numpy()
np_grad_output,
isign=_i_sign,
modeord=_mode_ordering,
**finufftkwargs,
)
)
)

return (
grads_points,
grad_values,
None,
None,
None,
None,
)
return (
grads_points,
grad_values,
None,
None,
None,
None,
)

0 comments on commit 48410ff

Please sign in to comment.