Skip to content

Commit

Permalink
feature: added kwargs to FFTND
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed May 1, 2024
1 parent 74e8c68 commit a44168a
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions pylops/signalprocessing/fftnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
ifftshift_before: bool = False,
fftshift_after: bool = False,
dtype: DTypeLike = "complex128",
**kwargs_fft,
) -> None:
super().__init__(
dims=dims,
Expand All @@ -45,7 +46,7 @@ def __init__(
warnings.warn(
f"numpy backend always returns complex128 dtype. To respect the passed dtype, data will be cast to {self.cdtype}."
)

self._kwargs_fft = kwargs_fft
self._norm_kwargs = {"norm": None} # equivalent to "backward" in Numpy/Scipy
if self.norm is _FFTNorms.ORTHO:
self._norm_kwargs["norm"] = "ortho"
Expand All @@ -61,13 +62,17 @@ def _matvec(self, x: NDArray) -> NDArray:
if not self.clinear:
x = np.real(x)
if self.real:
y = np.fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = np.fft.rfftn(
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
)
# Apply scaling to obtain a correct adjoint for this operator
y = np.swapaxes(y, -1, self.axes[-1])
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
y = np.swapaxes(y, self.axes[-1], -1)
else:
y = np.fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = np.fft.fftn(
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
y = y.astype(self.cdtype)
Expand All @@ -85,9 +90,13 @@ def _rmatvec(self, x: NDArray) -> NDArray:
x = np.swapaxes(x, -1, self.axes[-1])
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
x = np.swapaxes(x, self.axes[-1], -1)
y = np.fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = np.fft.irfftn(
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
)
else:
y = np.fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = np.fft.ifftn(
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
)
if self.norm is _FFTNorms.NONE:
y *= self._scale
for ax, nfft in zip(self.axes, self.nffts):
Expand Down Expand Up @@ -122,6 +131,7 @@ def __init__(
ifftshift_before: bool = False,
fftshift_after: bool = False,
dtype: DTypeLike = "complex128",
**kwargs_fft,
) -> None:
super().__init__(
dims=dims,
Expand All @@ -134,7 +144,7 @@ def __init__(
fftshift_after=fftshift_after,
dtype=dtype,
)

self._kwargs_fft = kwargs_fft
self._norm_kwargs = {"norm": None} # equivalent to "backward" in Numpy/Scipy
if self.norm is _FFTNorms.ORTHO:
self._norm_kwargs["norm"] = "ortho"
Expand All @@ -151,13 +161,17 @@ def _matvec(self, x: NDArray) -> NDArray:
if not self.clinear:
x = np.real(x)
if self.real:
y = sp_fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = sp_fft.rfftn(
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
)
# Apply scaling to obtain a correct adjoint for this operator
y = np.swapaxes(y, -1, self.axes[-1])
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
y = np.swapaxes(y, self.axes[-1], -1)
else:
y = sp_fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = sp_fft.fftn(
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
if self.fftshift_after.any():
Expand All @@ -175,9 +189,13 @@ def _rmatvec(self, x: NDArray) -> NDArray:
x = np.swapaxes(x, -1, self.axes[-1])
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
x = np.swapaxes(x, self.axes[-1], -1)
y = sp_fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = sp_fft.irfftn(
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
)
else:
y = sp_fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = sp_fft.ifftn(
x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
)
if self.norm is _FFTNorms.NONE:
y *= self._scale
for ax, nfft in zip(self.axes, self.nffts):
Expand Down Expand Up @@ -209,6 +227,7 @@ def FFTND(
engine: str = "scipy",
dtype: DTypeLike = "complex128",
name: str = "F",
**kwargs_fft,
):
r"""N-dimensional Fast-Fourier Transform.
Expand Down Expand Up @@ -311,6 +330,8 @@ def FFTND(
.. versionadded:: 2.0.0
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
**kwargs_fft
Arbitrary keyword arguments to be passed to the selected fft method
Attributes
----------
Expand Down Expand Up @@ -396,6 +417,7 @@ def FFTND(
ifftshift_before=ifftshift_before,
fftshift_after=fftshift_after,
dtype=dtype,
**kwargs_fft,
)
elif engine == "scipy":
f = _FFTND_scipy(
Expand All @@ -408,6 +430,7 @@ def FFTND(
ifftshift_before=ifftshift_before,
fftshift_after=fftshift_after,
dtype=dtype,
**kwargs_fft,
)
else:
raise NotImplementedError("engine must be numpy or scipy")
Expand Down

0 comments on commit a44168a

Please sign in to comment.