diff --git a/pylops/signalprocessing/fftnd.py b/pylops/signalprocessing/fftnd.py index a33f4918..d081072b 100644 --- a/pylops/signalprocessing/fftnd.py +++ b/pylops/signalprocessing/fftnd.py @@ -29,6 +29,7 @@ def __init__( ifftshift_before: bool = False, fftshift_after: bool = False, dtype: DTypeLike = "complex128", + **kwargs_fft, ) -> None: super().__init__( dims=dims, @@ -45,7 +46,7 @@ def __init__( warnings.warn( f"numpy backend always returns complex128 dtype. To respect the passed dtype, data will be cast to {self.cdtype}." ) - + self._kwargs_fft = kwargs_fft self._norm_kwargs = {"norm": None} # equivalent to "backward" in Numpy/Scipy if self.norm is _FFTNorms.ORTHO: self._norm_kwargs["norm"] = "ortho" @@ -61,13 +62,17 @@ def _matvec(self, x: NDArray) -> NDArray: if not self.clinear: x = np.real(x) if self.real: - y = np.fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + y = np.fft.rfftn( + x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft + ) # Apply scaling to obtain a correct adjoint for this operator y = np.swapaxes(y, -1, self.axes[-1]) y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2) y = np.swapaxes(y, self.axes[-1], -1) else: - y = np.fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + y = np.fft.fftn( + x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft + ) if self.norm is _FFTNorms.ONE_OVER_N: y *= self._scale y = y.astype(self.cdtype) @@ -85,9 +90,13 @@ def _rmatvec(self, x: NDArray) -> NDArray: x = np.swapaxes(x, -1, self.axes[-1]) x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2) x = np.swapaxes(x, self.axes[-1], -1) - y = np.fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + y = np.fft.irfftn( + x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft + ) else: - y = np.fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + y = np.fft.ifftn( + x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft + ) if self.norm is _FFTNorms.NONE: y *= self._scale for ax, nfft in zip(self.axes, self.nffts): @@ -122,6 +131,7 @@ def __init__( ifftshift_before: bool = False, fftshift_after: bool = False, dtype: DTypeLike = "complex128", + **kwargs_fft, ) -> None: super().__init__( dims=dims, @@ -134,7 +144,7 @@ def __init__( fftshift_after=fftshift_after, dtype=dtype, ) - + self._kwargs_fft = kwargs_fft self._norm_kwargs = {"norm": None} # equivalent to "backward" in Numpy/Scipy if self.norm is _FFTNorms.ORTHO: self._norm_kwargs["norm"] = "ortho" @@ -151,13 +161,17 @@ def _matvec(self, x: NDArray) -> NDArray: if not self.clinear: x = np.real(x) if self.real: - y = sp_fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + y = sp_fft.rfftn( + x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft + ) # Apply scaling to obtain a correct adjoint for this operator y = np.swapaxes(y, -1, self.axes[-1]) y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2) y = np.swapaxes(y, self.axes[-1], -1) else: - y = sp_fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + y = sp_fft.fftn( + x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft + ) if self.norm is _FFTNorms.ONE_OVER_N: y *= self._scale if self.fftshift_after.any(): @@ -175,9 +189,13 @@ def _rmatvec(self, x: NDArray) -> NDArray: x = np.swapaxes(x, -1, self.axes[-1]) x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2) x = np.swapaxes(x, self.axes[-1], -1) - y = sp_fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + y = sp_fft.irfftn( + x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft + ) else: - y = sp_fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + y = sp_fft.ifftn( + x, s=self.nffts, axes=self.axes, **self._norm_kwargs, **self._kwargs_fft + ) if self.norm is _FFTNorms.NONE: y *= self._scale for ax, nfft in zip(self.axes, self.nffts): @@ -209,6 +227,7 @@ def FFTND( engine: str = "scipy", dtype: DTypeLike = "complex128", name: str = "F", + **kwargs_fft, ): r"""N-dimensional Fast-Fourier Transform. @@ -311,6 +330,8 @@ def FFTND( .. versionadded:: 2.0.0 Name of operator (to be used by :func:`pylops.utils.describe.describe`) + **kwargs_fft + Arbitrary keyword arguments to be passed to the selected fft method Attributes ---------- @@ -396,6 +417,7 @@ def FFTND( ifftshift_before=ifftshift_before, fftshift_after=fftshift_after, dtype=dtype, + **kwargs_fft, ) elif engine == "scipy": f = _FFTND_scipy( @@ -408,6 +430,7 @@ def FFTND( ifftshift_before=ifftshift_before, fftshift_after=fftshift_after, dtype=dtype, + **kwargs_fft, ) else: raise NotImplementedError("engine must be numpy or scipy")