From a44168a7100cf7199538bf569c78a413673c6445 Mon Sep 17 00:00:00 2001
From: mrava87 <matteoravasi@gmail.com>
Date: Wed, 1 May 2024 10:25:57 +0300
Subject: [PATCH] feature: added kwargs to FFTND

---
 pylops/signalprocessing/fftnd.py | 43 ++++++++++++++++++++++++--------
 1 file changed, 33 insertions(+), 10 deletions(-)

diff --git a/pylops/signalprocessing/fftnd.py b/pylops/signalprocessing/fftnd.py
index a33f4918..d081072b 100644
--- a/pylops/signalprocessing/fftnd.py
+++ b/pylops/signalprocessing/fftnd.py
@@ -29,6 +29,7 @@ def __init__(
         ifftshift_before: bool = False,
         fftshift_after: bool = False,
         dtype: DTypeLike = "complex128",
+        **kwargs_fft,
     ) -> None:
         super().__init__(
             dims=dims,
@@ -45,7 +46,7 @@ def __init__(
             warnings.warn(
                 f"numpy backend always returns complex128 dtype. To respect the passed dtype, data will be cast to {self.cdtype}."
             )
-
+        self._kwargs_fft = kwargs_fft
         self._norm_kwargs = {"norm": None}  # equivalent to "backward" in Numpy/Scipy
         if self.norm is _FFTNorms.ORTHO:
             self._norm_kwargs["norm"] = "ortho"
@@ -61,13 +62,17 @@ def _matvec(self, x: NDArray) -> NDArray:
         if not self.clinear:
             x = np.real(x)
         if self.real:
-            y = np.fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+            y = np.fft.rfftn(
+                x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
+            )
             # Apply scaling to obtain a correct adjoint for this operator
             y = np.swapaxes(y, -1, self.axes[-1])
             y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
             y = np.swapaxes(y, self.axes[-1], -1)
         else:
-            y = np.fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+            y = np.fft.fftn(
+                x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
+            )
         if self.norm is _FFTNorms.ONE_OVER_N:
             y *= self._scale
         y = y.astype(self.cdtype)
@@ -85,9 +90,13 @@ def _rmatvec(self, x: NDArray) -> NDArray:
             x = np.swapaxes(x, -1, self.axes[-1])
             x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
             x = np.swapaxes(x, self.axes[-1], -1)
-            y = np.fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+            y = np.fft.irfftn(
+                x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
+            )
         else:
-            y = np.fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+            y = np.fft.ifftn(
+                x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
+            )
         if self.norm is _FFTNorms.NONE:
             y *= self._scale
         for ax, nfft in zip(self.axes, self.nffts):
@@ -122,6 +131,7 @@ def __init__(
         ifftshift_before: bool = False,
         fftshift_after: bool = False,
         dtype: DTypeLike = "complex128",
+        **kwargs_fft,
     ) -> None:
         super().__init__(
             dims=dims,
@@ -134,7 +144,7 @@ def __init__(
             fftshift_after=fftshift_after,
             dtype=dtype,
         )
-
+        self._kwargs_fft = kwargs_fft
         self._norm_kwargs = {"norm": None}  # equivalent to "backward" in Numpy/Scipy
         if self.norm is _FFTNorms.ORTHO:
             self._norm_kwargs["norm"] = "ortho"
@@ -151,13 +161,17 @@ def _matvec(self, x: NDArray) -> NDArray:
         if not self.clinear:
             x = np.real(x)
         if self.real:
-            y = sp_fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+            y = sp_fft.rfftn(
+                x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
+            )
             # Apply scaling to obtain a correct adjoint for this operator
             y = np.swapaxes(y, -1, self.axes[-1])
             y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
             y = np.swapaxes(y, self.axes[-1], -1)
         else:
-            y = sp_fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+            y = sp_fft.fftn(
+                x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
+            )
         if self.norm is _FFTNorms.ONE_OVER_N:
             y *= self._scale
         if self.fftshift_after.any():
@@ -175,9 +189,13 @@ def _rmatvec(self, x: NDArray) -> NDArray:
             x = np.swapaxes(x, -1, self.axes[-1])
             x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
             x = np.swapaxes(x, self.axes[-1], -1)
-            y = sp_fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+            y = sp_fft.irfftn(
+                x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
+            )
         else:
-            y = sp_fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+            y = sp_fft.ifftn(
+                x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft
+            )
         if self.norm is _FFTNorms.NONE:
             y *= self._scale
         for ax, nfft in zip(self.axes, self.nffts):
@@ -209,6 +227,7 @@ def FFTND(
     engine: str = "scipy",
     dtype: DTypeLike = "complex128",
     name: str = "F",
+    **kwargs_fft,
 ):
     r"""N-dimensional Fast-Fourier Transform.
 
@@ -311,6 +330,8 @@ def FFTND(
         .. versionadded:: 2.0.0
 
         Name of operator (to be used by :func:`pylops.utils.describe.describe`)
+    **kwargs_fft
+            Arbitrary keyword arguments to be passed to the selected fft method
 
     Attributes
     ----------
@@ -396,6 +417,7 @@ def FFTND(
             ifftshift_before=ifftshift_before,
             fftshift_after=fftshift_after,
             dtype=dtype,
+            **kwargs_fft,
         )
     elif engine == "scipy":
         f = _FFTND_scipy(
@@ -408,6 +430,7 @@ def FFTND(
             ifftshift_before=ifftshift_before,
             fftshift_after=fftshift_after,
             dtype=dtype,
+            **kwargs_fft,
         )
     else:
         raise NotImplementedError("engine must be numpy or scipy")