diff --git a/environment-dev.yml b/environment-dev.yml index f00830c2..b2a18d23 100755 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -7,6 +7,9 @@ channels: dependencies: - python>=3.6.4 - pip + - mkl + - mkl_fft + - mkl-service - numpy>=1.21.0 - scipy>=1.4.0 - pytorch>=1.2.0 diff --git a/examples/plot_fft.py b/examples/plot_fft.py index a8af9da6..10f9db47 100644 --- a/examples/plot_fft.py +++ b/examples/plot_fft.py @@ -6,173 +6,194 @@ and :py:class:`pylops.signalprocessing.FFTND` operators to apply the Fourier Transform to the model and the inverse Fourier Transform to the data. """ -import matplotlib.pyplot as plt -import numpy as np - +# import matplotlib.pyplot as plt import pylops -plt.close("all") - -############################################################################### -# Let's start by applying the one dimensional FFT to a one dimensional -# sinusoidal signal :math:`d(t)=sin(2 \pi f_0t)` using a time axis of -# lenght :math:`nt` and sampling :math:`dt` +import numpy as np +# +# +# plt.close("all") +# +# ############################################################################### +# # Let's start by applying the one dimensional FFT to a one dimensional +# # sinusoidal signal :math:`d(t)=sin(2 \pi f_0t)` using a time axis of +# # lenght :math:`nt` and sampling :math:`dt` dt = 0.005 nt = 100 t = np.arange(nt) * dt f0 = 10 nfft = 2**10 d = np.sin(2 * np.pi * f0 * t) - -FFTop = pylops.signalprocessing.FFT(dims=nt, nfft=nfft, sampling=dt, engine="numpy") -D = FFTop * d - -# Adjoint = inverse for FFT -dinv = FFTop.H * D -dinv = FFTop / D - -fig, axs = plt.subplots(1, 2, figsize=(10, 4)) -axs[0].plot(t, d, "k", lw=2, label="True") -axs[0].plot(t, dinv.real, "--r", lw=2, label="Inverted") -axs[0].legend() -axs[0].set_title("Signal") -axs[1].plot(FFTop.f[: int(FFTop.nfft / 2)], np.abs(D[: int(FFTop.nfft / 2)]), "k", lw=2) -axs[1].set_title("Fourier Transform") -axs[1].set_xlim([0, 3 * f0]) -plt.tight_layout() +# +# FFTop = pylops.FFT(dims=nt, nfft=nfft, sampling=dt, engine="numpy") +# D = FFTop * d +# +# # Adjoint = inverse for FFT +# dinv = FFTop.H * D +# dinv = FFTop / D +# +# fig, axs = plt.subplots(1, 2, figsize=(10, 4)) +# axs[0].plot(t, d, "k", lw=2, label="True") +# axs[0].plot(t, dinv.real, "--r", lw=2, label="Inverted") +# axs[0].legend() +# axs[0].set_title("Signal") +# axs[1].plot(FFTop.f[: int(FFTop.nfft / 2)], np.abs(D[: int(FFTop.nfft / 2)]), "k", lw=2) +# axs[1].set_title("Fourier Transform") +# axs[1].set_xlim([0, 3 * f0]) +# plt.tight_layout() +# +# ############################################################################### +# # In this example we used numpy as our engine for the ``fft`` and ``ifft``. +# # PyLops implements a second engine (``engine='fftw'``) which uses the +# # well-known `FFTW `_ via the python wrapper +# # :py:class:`pyfftw.FFTW`. This optimized fft tends to outperform the one from +# # numpy in many cases but it is not inserted in the mandatory requirements of +# # PyLops. If interested to use ``FFTW`` backend, read the `fft routines` +# # section at :ref:`performance`. +# FFTop = pylops.FFT(dims=nt, nfft=nfft, sampling=dt, engine="fftw") +# D = FFTop * d +# +# # Adjoint = inverse for FFT +# dinv = FFTop.H * D +# dinv = FFTop / D +# +# fig, axs = plt.subplots(1, 2, figsize=(10, 4)) +# axs[0].plot(t, d, "k", lw=2, label="True") +# axs[0].plot(t, dinv.real, "--r", lw=2, label="Inverted") +# axs[0].legend() +# axs[0].set_title("Signal") +# axs[1].plot(FFTop.f[: int(FFTop.nfft / 2)], np.abs(D[: int(FFTop.nfft / 2)]), "k", lw=2) +# axs[1].set_title("Fourier Transform with fftw") +# axs[1].set_xlim([0, 3 * f0]) +# plt.tight_layout() ############################################################################### -# In this example we used numpy as our engine for the ``fft`` and ``ifft``. -# PyLops implements a second engine (``engine='fftw'``) which uses the -# well-known `FFTW `_ via the python wrapper -# :py:class:`pyfftw.FFTW`. This optimized fft tends to outperform the one from -# numpy in many cases but it is not inserted in the mandatory requirements of -# PyLops. If interested to use ``FFTW`` backend, read the `fft routines` -# section at :ref:`performance`. -FFTop = pylops.signalprocessing.FFT(dims=nt, nfft=nfft, sampling=dt, engine="fftw") +# PyLops implements a third engine (``engine='mkl_fft'``) which uses the +# well-known `mkl_fft `_ . +FFTop = pylops.FFT(dims=nt, nfft=nfft, sampling=dt, engine="mkl_fft") D = FFTop * d # Adjoint = inverse for FFT -dinv = FFTop.H * D -dinv = FFTop / D - -fig, axs = plt.subplots(1, 2, figsize=(10, 4)) -axs[0].plot(t, d, "k", lw=2, label="True") -axs[0].plot(t, dinv.real, "--r", lw=2, label="Inverted") -axs[0].legend() -axs[0].set_title("Signal") -axs[1].plot(FFTop.f[: int(FFTop.nfft / 2)], np.abs(D[: int(FFTop.nfft / 2)]), "k", lw=2) -axs[1].set_title("Fourier Transform with fftw") -axs[1].set_xlim([0, 3 * f0]) -plt.tight_layout() +# dinv = FFTop.H * D +# dinv = FFTop / D +# +# fig, axs = plt.subplots(1, 2, figsize=(10, 4)) +# axs[0].plot(t, d, "k", lw=2, label="True") +# axs[0].plot(t, dinv.real, "--r", lw=2, label="Inverted") +# axs[0].legend() +# axs[0].set_title("Signal") +# axs[1].plot(FFTop.f[: int(FFTop.nfft / 2)], np.abs(D[: int(FFTop.nfft / 2)]), "k", lw=2) +# axs[1].set_title("Fourier Transform with mkl_fft") +# axs[1].set_xlim([0, 3 * f0]) +# plt.tight_layout() ############################################################################### # We can also apply the one dimensional FFT to to a two-dimensional # signal (along one of the first axis) -dt = 0.005 -nt, nx = 100, 20 -t = np.arange(nt) * dt -f0 = 10 -nfft = 2**10 -d = np.outer(np.sin(2 * np.pi * f0 * t), np.arange(nx) + 1) - -FFTop = pylops.signalprocessing.FFT(dims=(nt, nx), axis=0, nfft=nfft, sampling=dt) -D = FFTop * d.ravel() - -# Adjoint = inverse for FFT -dinv = FFTop.H * D -dinv = FFTop / D -dinv = np.real(dinv).reshape(nt, nx) - -fig, axs = plt.subplots(2, 2, figsize=(10, 6)) -axs[0][0].imshow(d, vmin=-20, vmax=20, cmap="bwr") -axs[0][0].set_title("Signal") -axs[0][0].axis("tight") -axs[0][1].imshow(np.abs(D.reshape(nfft, nx)[:200, :]), cmap="bwr") -axs[0][1].set_title("Fourier Transform") -axs[0][1].axis("tight") -axs[1][0].imshow(dinv, vmin=-20, vmax=20, cmap="bwr") -axs[1][0].set_title("Inverted") -axs[1][0].axis("tight") -axs[1][1].imshow(d - dinv, vmin=-20, vmax=20, cmap="bwr") -axs[1][1].set_title("Error") -axs[1][1].axis("tight") -fig.tight_layout() - -############################################################################### -# We can also apply the two dimensional FFT to to a two-dimensional signal -dt, dx = 0.005, 5 -nt, nx = 100, 201 -t = np.arange(nt) * dt -x = np.arange(nx) * dx -f0 = 10 -nfft = 2**10 -d = np.outer(np.sin(2 * np.pi * f0 * t), np.arange(nx) + 1) - -FFTop = pylops.signalprocessing.FFT2D( - dims=(nt, nx), nffts=(nfft, nfft), sampling=(dt, dx) -) -D = FFTop * d.ravel() - -dinv = FFTop.H * D -dinv = FFTop / D -dinv = np.real(dinv).reshape(nt, nx) - -fig, axs = plt.subplots(2, 2, figsize=(10, 6)) -axs[0][0].imshow(d, vmin=-100, vmax=100, cmap="bwr") -axs[0][0].set_title("Signal") -axs[0][0].axis("tight") -axs[0][1].imshow( - np.abs(np.fft.fftshift(D.reshape(nfft, nfft), axes=1)[:200, :]), cmap="bwr" -) -axs[0][1].set_title("Fourier Transform") -axs[0][1].axis("tight") -axs[1][0].imshow(dinv, vmin=-100, vmax=100, cmap="bwr") -axs[1][0].set_title("Inverted") -axs[1][0].axis("tight") -axs[1][1].imshow(d - dinv, vmin=-100, vmax=100, cmap="bwr") -axs[1][1].set_title("Error") -axs[1][1].axis("tight") -fig.tight_layout() - - -############################################################################### -# Finally can apply the three dimensional FFT to to a three-dimensional signal -dt, dx, dy = 0.005, 5, 3 -nt, nx, ny = 30, 21, 11 -t = np.arange(nt) * dt -x = np.arange(nx) * dx -y = np.arange(nx) * dy -f0 = 10 -nfft = 2**6 -nfftk = 2**5 - -d = np.outer(np.sin(2 * np.pi * f0 * t), np.arange(nx) + 1) -d = np.tile(d[:, :, np.newaxis], [1, 1, ny]) - -FFTop = pylops.signalprocessing.FFTND( - dims=(nt, nx, ny), nffts=(nfft, nfftk, nfftk), sampling=(dt, dx, dy) -) -D = FFTop * d.ravel() - -dinv = FFTop.H * D -dinv = FFTop / D -dinv = np.real(dinv).reshape(nt, nx, ny) - -fig, axs = plt.subplots(2, 2, figsize=(10, 6)) -axs[0][0].imshow(d[:, :, ny // 2], vmin=-20, vmax=20, cmap="bwr") -axs[0][0].set_title("Signal") -axs[0][0].axis("tight") -axs[0][1].imshow( - np.abs(np.fft.fftshift(D.reshape(nfft, nfftk, nfftk), axes=1)[:20, :, nfftk // 2]), - cmap="bwr", -) -axs[0][1].set_title("Fourier Transform") -axs[0][1].axis("tight") -axs[1][0].imshow(dinv[:, :, ny // 2], vmin=-20, vmax=20, cmap="bwr") -axs[1][0].set_title("Inverted") -axs[1][0].axis("tight") -axs[1][1].imshow(d[:, :, ny // 2] - dinv[:, :, ny // 2], vmin=-20, vmax=20, cmap="bwr") -axs[1][1].set_title("Error") -axs[1][1].axis("tight") -fig.tight_layout() +# dt = 0.005 +# nt, nx = 100, 20 +# t = np.arange(nt) * dt +# f0 = 10 +# nfft = 2**10 +# d = np.outer(np.sin(2 * np.pi * f0 * t), np.arange(nx) + 1) +# +# FFTop = pylops.signalprocessing.FFT(dims=(nt, nx), axis=0, nfft=nfft, sampling=dt) +# D = FFTop * d.ravel() +# +# # Adjoint = inverse for FFT +# dinv = FFTop.H * D +# dinv = FFTop / D +# dinv = np.real(dinv).reshape(nt, nx) +# +# fig, axs = plt.subplots(2, 2, figsize=(10, 6)) +# axs[0][0].imshow(d, vmin=-20, vmax=20, cmap="bwr") +# axs[0][0].set_title("Signal") +# axs[0][0].axis("tight") +# axs[0][1].imshow(np.abs(D.reshape(nfft, nx)[:200, :]), cmap="bwr") +# axs[0][1].set_title("Fourier Transform") +# axs[0][1].axis("tight") +# axs[1][0].imshow(dinv, vmin=-20, vmax=20, cmap="bwr") +# axs[1][0].set_title("Inverted") +# axs[1][0].axis("tight") +# axs[1][1].imshow(d - dinv, vmin=-20, vmax=20, cmap="bwr") +# axs[1][1].set_title("Error") +# axs[1][1].axis("tight") +# fig.tight_layout() +# +# ############################################################################### +# # We can also apply the two dimensional FFT to to a two-dimensional signal +# dt, dx = 0.005, 5 +# nt, nx = 100, 201 +# t = np.arange(nt) * dt +# x = np.arange(nx) * dx +# f0 = 10 +# nfft = 2**10 +# d = np.outer(np.sin(2 * np.pi * f0 * t), np.arange(nx) + 1) +# +# FFTop = pylops.FFT2D( +# dims=(nt, nx), nffts=(nfft, nfft), sampling=(dt, dx) +# ) +# D = FFTop * d.ravel() +# +# dinv = FFTop.H * D +# dinv = FFTop / D +# dinv = np.real(dinv).reshape(nt, nx) +# +# fig, axs = plt.subplots(2, 2, figsize=(10, 6)) +# axs[0][0].imshow(d, vmin=-100, vmax=100, cmap="bwr") +# axs[0][0].set_title("Signal") +# axs[0][0].axis("tight") +# axs[0][1].imshow( +# np.abs(np.fft.fftshift(D.reshape(nfft, nfft), axes=1)[:200, :]), cmap="bwr" +# ) +# axs[0][1].set_title("Fourier Transform") +# axs[0][1].axis("tight") +# axs[1][0].imshow(dinv, vmin=-100, vmax=100, cmap="bwr") +# axs[1][0].set_title("Inverted") +# axs[1][0].axis("tight") +# axs[1][1].imshow(d - dinv, vmin=-100, vmax=100, cmap="bwr") +# axs[1][1].set_title("Error") +# axs[1][1].axis("tight") +# fig.tight_layout() +# +# +# ############################################################################### +# # Finally can apply the three dimensional FFT to to a three-dimensional signal +# dt, dx, dy = 0.005, 5, 3 +# nt, nx, ny = 30, 21, 11 +# t = np.arange(nt) * dt +# x = np.arange(nx) * dx +# y = np.arange(nx) * dy +# f0 = 10 +# nfft = 2**6 +# nfftk = 2**5 +# +# d = np.outer(np.sin(2 * np.pi * f0 * t), np.arange(nx) + 1) +# d = np.tile(d[:, :, np.newaxis], [1, 1, ny]) +# +# FFTop = pylops.FFTND( +# dims=(nt, nx, ny), nffts=(nfft, nfftk, nfftk), sampling=(dt, dx, dy) +# ) +# D = FFTop * d.ravel() +# +# dinv = FFTop.H * D +# dinv = FFTop / D +# dinv = np.real(dinv).reshape(nt, nx, ny) +# +# fig, axs = plt.subplots(2, 2, figsize=(10, 6)) +# axs[0][0].imshow(d[:, :, ny // 2], vmin=-20, vmax=20, cmap="bwr") +# axs[0][0].set_title("Signal") +# axs[0][0].axis("tight") +# axs[0][1].imshow( +# np.abs(np.fft.fftshift(D.reshape(nfft, nfftk, nfftk), axes=1)[:20, :, nfftk // 2]), +# cmap="bwr", +# ) +# axs[0][1].set_title("Fourier Transform") +# axs[0][1].axis("tight") +# axs[1][0].imshow(dinv[:, :, ny // 2], vmin=-20, vmax=20, cmap="bwr") +# axs[1][0].set_title("Inverted") +# axs[1][0].axis("tight") +# axs[1][1].imshow(d[:, :, ny // 2] - dinv[:, :, ny // 2], vmin=-20, vmax=20, cmap="bwr") +# axs[1][1].set_title("Error") +# axs[1][1].axis("tight") +# fig.tight_layout() diff --git a/pylops/__init__.py b/pylops/__init__.py index 55d4ce3d..770e628f 100755 --- a/pylops/__init__.py +++ b/pylops/__init__.py @@ -47,7 +47,6 @@ from .config import * from .linearoperator import * -from .torchoperator import * from .basicoperators import * from . import ( avo, @@ -57,6 +56,8 @@ utils, waveeqprocessing, ) +from .torchoperator import * +from .signalprocessing import * from .avo.poststack import * from .avo.prestack import * from .optimization.basic import * diff --git a/pylops/signalprocessing/fft.py b/pylops/signalprocessing/fft.py index 64444bcd..ce037cb9 100644 --- a/pylops/signalprocessing/fft.py +++ b/pylops/signalprocessing/fft.py @@ -15,10 +15,15 @@ from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray pyfftw_message = deps.pyfftw_import("the fft module") +mkl_fft_message = deps.mkl_fft_import("the mkl fft module") if pyfftw_message is None: import pyfftw +if mkl_fft_message is None: + from mkl_fft import _numpy_fft + from mkl_fft import _scipy_fft_backend + logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) @@ -365,6 +370,92 @@ def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike: return self._rmatvec(y) / self._scale +class _FFT_mklfft(_BaseFFT): + """One-dimensional Fast-Fourier Transform using mkl_fft""" + + def __init__( + self, + dims: Union[int, InputDimsLike], + axis: int = -1, + nfft: Optional[int] = None, + sampling: float = 1.0, + norm: str = "ortho", + real: bool = False, + ifftshift_before: bool = False, + fftshift_after: bool = False, + dtype: DTypeLike = "complex128", + ) -> None: + super().__init__( + dims=dims, + axis=axis, + nfft=nfft, + sampling=sampling, + norm=norm, + real=real, + ifftshift_before=ifftshift_before, + fftshift_after=fftshift_after, + dtype=dtype, + ) + self._norm_kwargs = {"norm": None} + if self.norm is _FFTNorms.ORTHO: + self._norm_kwargs["norm"] = "ortho" + self._scale = np.sqrt(1 / self.nfft) + elif self.norm is _FFTNorms.NONE: + self._scale = self.nfft + elif self.norm is _FFTNorms.ONE_OVER_N: + self._scale = 1.0 / self.nfft + + @reshaped + def _matvec(self, x: NDArray) -> NDArray: + if self.ifftshift_before: + x = _scipy_fft_backend.ifftshift(x, axes=self.axis) + if not self.clinear: + x = np.real(x) + if self.real: + y = _numpy_fft.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) + y = np.swapaxes(y, -1, self.axis) + y[..., 1 : 1 + (self.nfft - 1) // 2] *= np.sqrt(2) + y = np.swapaxes(y, self.axis, -1) + else: + y = _numpy_fft.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) + if self.norm is _FFTNorms.ONE_OVER_N: + y *= self._scale + if self.fftshift_after: + y = _scipy_fft_backend.fftshift(y, axes=self.axis) + return y + + @reshaped + def _rmatvec(self, x: NDArray) -> NDArray: + if self.fftshift_after: + x = _scipy_fft_backend.ifftshift(x, axes=self.axis) + if self.real: + x = x.copy() + x = np.swapaxes(x, -1, self.axis) + x[..., 1 : 1 + (self.nfft - 1) // 2] /= np.sqrt(2) + x = np.swapaxes(x, self.axis, -1) + y = _numpy_fft.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) + else: + y = _numpy_fft.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs) + if self.norm is _FFTNorms.NONE: + y *= self._scale + + if self.nfft > self.dims[self.axis]: + y = np.take(y, range(0, self.dims[self.axis]), axis=self.axis) + elif self.nfft < self.dims[self.axis]: + y = np.pad(y, self.ifftpad) + + if not self.clinear: + y = np.real(y) + if self.ifftshift_before: + y = _scipy_fft_backend.fftshift(y, axes=self.axis) + return y + + def __truediv__(self, y): + if self.norm is not _FFTNorms.ORTHO: + return self._rmatvec(y) / self._scale + return self._rmatvec(y) + + def FFT( dims: Union[int, InputDimsLike], axis: int = -1, @@ -395,6 +486,10 @@ def FFT( forward mode, and to :py:func:`scipy.fft.ifft` (or :py:func:`scipy.fft.irfft` for real models) in adjoint mode. + When the mkl_fft engine is chosen, the overloads are of :py:func: 'mkl_fft._numpy_fft.fft' + (or :py:func:`mkl_fft._numpy_fft.rfft` for real models) in forward mode and to :py:func:`mkl_fft._numpy_fft.ifft` + (or :py:func:`mkl_fft._numpy_fft.irfft`for real models) in adjoint mode. + When using ``real=True``, the result of the forward is also multiplied by :math:`\sqrt{2}` for all frequency bins except zero and Nyquist, and the input of the adjoint is multiplied by :math:`1 / \sqrt{2}` for the same frequencies. @@ -452,7 +547,7 @@ def FFT( frequencies are arranged from zero to largest positive, and then from negative Nyquist to the frequency bin before zero. engine : :obj:`str`, optional - Engine used for fft computation (``numpy``, ``fftw``, or ``scipy``). Choose + Engine used for fft computation (``numpy``, ``fftw``, ``scipy`` or ``mkl_fft``). Choose ``numpy`` when working with cupy arrays. .. note:: Since version 1.17.0, accepts "scipy". @@ -505,7 +600,7 @@ def FFT( - If ``dims`` is provided and ``axis`` is bigger than ``len(dims)``. - If ``norm`` is not one of "ortho", "none", or "1/n". NotImplementedError - If ``engine`` is neither ``numpy``, ``fftw``, nor ``scipy``. + If ``engine`` is neither ``numpy``, ``fftw``, ``scipy`` nor ``mkl_fft``. See Also -------- @@ -550,7 +645,19 @@ def FFT( dtype=dtype, **kwargs_fftw, ) - elif engine == "numpy" or (engine == "fftw" and pyfftw_message is not None): + elif engine == "mkl_fft" and mkl_fft_message is None: + f = _FFT_mklfft( + dims, + axis=axis, + nfft=nfft, + sampling=sampling, + norm=norm, + real=real, + ifftshift_before=ifftshift_before, + fftshift_after=fftshift_after, + dtype=dtype, + ) + elif engine == "numpy" or (engine == "fftw" and pyfftw_message is not None) or (engine == "mkl_fft" and mkl_fft_message is not None): if engine == "fftw" and pyfftw_message is not None: logging.warning(pyfftw_message) f = _FFT_numpy( @@ -577,6 +684,6 @@ def FFT( dtype=dtype, ) else: - raise NotImplementedError("engine must be numpy, fftw or scipy") + raise NotImplementedError("engine must be numpy, fftw, scipy or mkl_fft") f.name = name return f diff --git a/pylops/signalprocessing/fft2d.py b/pylops/signalprocessing/fft2d.py index f54e2972..30938cda 100644 --- a/pylops/signalprocessing/fft2d.py +++ b/pylops/signalprocessing/fft2d.py @@ -11,6 +11,13 @@ from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike +from pylops.utils import deps +from pylops.signalprocessing.fftnd import _FFTND_mklfft + +mkl_fft_message = deps.mkl_fft_import("the mkl fft module") + +if mkl_fft_message is None: + from mkl_fft import _scipy_fft_backend logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) @@ -215,6 +222,100 @@ def __truediv__(self, y): return self._rmatvec(y) +class _FFT2D_mklfft(_BaseFFTND): + """Two-dimensional Fast-Fourier Transform using mkl_fft""" + + def __init__( + self, + dims: InputDimsLike, + axes: InputDimsLike = (-2, -1), + nffts: Optional[Union[int, InputDimsLike]] = None, + sampling: Union[float, Sequence[float]] = 1.0, + norm: str = "ortho", + real: bool = False, + ifftshift_before: bool = False, + fftshift_after: bool = False, + dtype: DTypeLike = "complex128", + ) -> None: + super().__init__( + dims=dims, + axes=axes, + nffts=nffts, + sampling=sampling, + norm=norm, + real=real, + ifftshift_before=ifftshift_before, + fftshift_after=fftshift_after, + dtype=dtype, + ) + + # checks + if self.ndim < 2: + raise ValueError("FFT2D requires at least two input dimensions") + if self.naxes != 2: + raise ValueError("FFT2D must be applied along exactly two dimensions") + + self.f1, self.f2 = self.fs + del self.fs + + self._norm_kwargs: Dict[str, Union[None, str]] = { + "norm": None + } + if self.norm is _FFTNorms.ORTHO: + self._norm_kwargs["norm"] = "ortho" + self._scale = np.sqrt(1 / np.prod(np.sqrt(self.nffts))) + elif self.norm is _FFTNorms.NONE: + self._scale = np.sqrt(np.prod(self.nffts)) + elif self.norm is _FFTNorms.ONE_OVER_N: + self._scale = np.sqrt(1.0 / np.prod(self.nffts)) + + @reshaped + def _matvec(self, x): + if self.ifftshift_before.any(): + x = _scipy_fft_backend.ifftshift(x, axes=self.axes[self.ifftshift_before]) + if not self.clinear: + x = np.real(x) + if self.real: + y = _FFTND_mklfft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + y = np.swapaxes(y, -1, self.axes[-1]) + y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2) + y = np.swapaxes(y, self.axes[-1], -1) + else: + y = _FFTND_mklfft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + if self.norm is _FFTNorms.ONE_OVER_N: + y *= self._scale + if self.fftshift_after.any(): + y = _scipy_fft_backend.fftshift(y, axes=self.axes[self.fftshift_after]) + return y + + @reshaped + def _rmatvec(self, x): + if self.fftshift_after.any(): + x = _scipy_fft_backend.ifftshift(x, axes=self.axes[self.fftshift_after]) + if self.real: + x = x.copy() + x = np.swapaxes(x, -1, self.axes[-1]) + x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2) + x = np.swapaxes(x, self.axes[-1], -1) + y = _FFTND_mklfft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + else: + y = _FFTND_mklfft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs) + if self.norm is _FFTNorms.NONE: + y *= self._scale + y = np.take(y, range(self.dims[self.axes[0]]), axis=self.axes[0]) + y = np.take(y, range(self.dims[self.axes[1]]), axis=self.axes[1]) + if not self.clinear: + y = np.real(y) + if self.ifftshift_before.any(): + y = _scipy_fft_backend.fftshift(y, axes=self.axes[self.ifftshift_before]) + return y + + def __truediv__(self, y): + if self.norm is not _FFTNorms.ORTHO: + return self._rmatvec(y) / self._scale / self._scale + return self._rmatvec(y) + + def FFT2D( dims: InputDimsLike, axes: InputDimsLike = (-2, -1), @@ -242,6 +343,10 @@ def FFT2D( forward mode, and to :py:func:`scipy.fft.ifft2` (or :py:func:`scipy.fft.irfft2` for real models) in adjoint mode. + When the mkl_fft engine is chosen, the overloads are of :py:func: 'mkl_fft._numpy_fft.fft2' + (or :py:func:`mkl_fft._numpy_fft.fft2` for real models) in forward mode and to :py:func:`mkl_fft._numpy_fft.ifft2` + (or :py:func:`mkl_fft._numpy_fft.irfft2`for real models) in adjoint mode. + When using ``real=True``, the result of the forward is also multiplied by :math:`\sqrt{2}` for all frequency bins except zero and Nyquist, and the input of the adjoint is multiplied by :math:`1 / \sqrt{2}` for the same frequencies. @@ -310,7 +415,7 @@ def FFT2D( engine : :obj:`str`, optional .. versionadded:: 1.17.0 - Engine used for fft computation (``numpy`` or ``scipy``). + Engine used for fft computation (``numpy`` or ``scipy`` or ``mkl_fft``). dtype : :obj:`str`, optional Type of elements in input array. Note that the ``dtype`` of the operator is the corresponding complex type even when a real type is provided. @@ -361,7 +466,7 @@ def FFT2D( two elements. - If ``norm`` is not one of "ortho", "none", or "1/n". NotImplementedError - If ``engine`` is neither ``numpy``, nor ``scipy``. + If ``engine`` is neither ``numpy``, ``scipy`` nor ``mkl_fft``. See Also -------- @@ -394,7 +499,19 @@ def FFT2D( signals. """ - if engine == "numpy": + if engine == "mkl_fft" and mkl_fft_message is None: + f = _FFT2D_mklfft( + dims=dims, + axes=axes, + nffts=nffts, + sampling=sampling, + norm=norm, + real=real, + ifftshift_before=ifftshift_before, + fftshift_after=fftshift_after, + dtype=dtype, + ) + elif engine == "numpy" or (engine == "mkl_fft" and mkl_fft_message is not None): f = _FFT2D_numpy( dims=dims, axes=axes, @@ -419,6 +536,6 @@ def FFT2D( dtype=dtype, ) else: - raise NotImplementedError("engine must be numpy or scipy") + raise NotImplementedError("engine must be numpy, scipy or mkl_fft") f.name = name return f diff --git a/pylops/signalprocessing/fftnd.py b/pylops/signalprocessing/fftnd.py index a33f4918..1def1f62 100644 --- a/pylops/signalprocessing/fftnd.py +++ b/pylops/signalprocessing/fftnd.py @@ -11,6 +11,13 @@ from pylops.utils.backend import get_sp_fft from pylops.utils.decorators import reshaped from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray +from pylops.utils import deps + +mkl_fft_message = deps.mkl_fft_import("the mkl fft module") + +if mkl_fft_message is None: + from mkl_fft import _numpy_fft + from mkl_fft import _scipy_fft_backend logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING) @@ -197,6 +204,130 @@ def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike: return self._rmatvec(y) +class _FFTND_mklfft(_BaseFFTND): + """N-dimensional Fast-Fourier Transform using mkl_fft""" + + def __init__( + self, + dims: Union[int, InputDimsLike], + axes: Union[int, InputDimsLike] = (-3, -2, -1), + nffts: Optional[Union[int, InputDimsLike]] = None, + sampling: Union[float, Sequence[float]] = 1.0, + norm: str = "ortho", + real: bool = False, + ifftshift_before: bool = False, + fftshift_after: bool = False, + dtype: DTypeLike = "complex128", + ) -> None: + super().__init__( + dims=dims, + axes=axes, + nffts=nffts, + sampling=sampling, + norm=norm, + real=real, + ifftshift_before=ifftshift_before, + fftshift_after=fftshift_after, + dtype=dtype, + ) + + self._norm_kwargs = {"norm": None} + if self.norm is _FFTNorms.ORTHO: + self._norm_kwargs["norm"] = "ortho" + self._scale = np.sqrt(1 / np.prod(self.nffts)) + elif self.norm is _FFTNorms.NONE: + self._scale = np.prod(self.nffts) + elif self.norm is _FFTNorms.ONE_OVER_N: + self._scale = 1.0 / np.prod(self.nffts) + + @reshaped + def _matvec(self, x: NDArray) -> NDArray: + if self.ifftshift_before.any(): + x = _scipy_fft_backend.ifftshift(x, axes=self.axes[self.ifftshift_before]) + if not self.clinear: + x = np.real(x) + if self.real: + y = self.rfftn(x, s=self.nffts, axes=tuple(self.axes), **self._norm_kwargs) + y = np.swapaxes(y, -1, self.axes[-1]) + y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2) + y = np.swapaxes(y, self.axes[-1], -1) + else: + y = self.fftn(x, s=self.nffts, axes=tuple(self.axes), **self._norm_kwargs) + if self.norm is _FFTNorms.ONE_OVER_N: + y *= self._scale + if self.fftshift_after.any(): + y = _scipy_fft_backend.fftshift(y, axes=self.axes[self.fftshift_after]) + return y + + @reshaped + def _rmatvec(self, x: NDArray) -> NDArray: + if self.fftshift_after.any(): + x = _scipy_fft_backend.ifftshift(x, axes=self.axes[self.fftshift_after]) + if self.real: + x = x.copy() + x = np.swapaxes(x, -1, self.axes[-1]) + x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2) + x = np.swapaxes(x, self.axes[-1], -1) + y = self.irfftn(x, s=self.nffts, axes=tuple(self.axes), **self._norm_kwargs) + else: + y = self.ifftn(x, s=self.nffts, axes=tuple(self.axes), **self._norm_kwargs) + if self.norm is _FFTNorms.NONE: + y *= self._scale + for ax, nfft in zip(self.axes, self.nffts): + if nfft > self.dims[ax]: + y = np.take(y, range(self.dims[ax]), axis=ax) + if self.doifftpad: + y = np.pad(y, self.ifftpad) + if not self.clinear: + y = np.real(y) + if self.ifftshift_before.any(): + y = _scipy_fft_backend.fftshift(y, axes=self.axes[self.ifftshift_before]) + return y + + @staticmethod + def rfftn(a, s=None, axes=None, norm=None): + a = np.asarray(a) + s, axes = _numpy_fft._cook_nd_args(a, s, axes) + a = _numpy_fft.rfft(a, s[-1], axes[-1], norm) + for ii in range(len(axes) - 1): + a = _numpy_fft.fft(a, s[ii], axes[ii], norm) + return a + + @staticmethod + def irfftn(a, s=None, axes=None, norm=None): + a = np.asarray(a) + s, axes = _numpy_fft._cook_nd_args(a, s, axes, invreal=1) + for ii in range(len(axes) - 1): + a = _numpy_fft.ifft(a, s[ii], axes[ii], norm) + a = _numpy_fft.irfft(a, s[-1], axes[-1], norm) + return a + + @staticmethod + def fftn(a, s=None, axes=None, norm=None): + a = np.asarray(a) + s, axes = _numpy_fft._cook_nd_args(a, s, axes) + itl = list(range(len(axes))) + itl.reverse() + for ii in itl: + a = _numpy_fft.fft(a, n=s[ii], axis=axes[ii], norm=norm) + return a + + @staticmethod + def ifftn(a, s=None, axes=None, norm=None): + a = np.asarray(a) + s, axes = _numpy_fft._cook_nd_args(a, s, axes) + itl = list(range(len(axes))) + itl.reverse() + for ii in itl: + a = _numpy_fft.ifft(a, n=s[ii], axis=axes[ii], norm=norm) + return a + + def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike: + if self.norm is not _FFTNorms.ORTHO: + return self._rmatvec(y) / self._scale + return self._rmatvec(y) + + def FFTND( dims: Union[int, InputDimsLike], axes: Union[int, InputDimsLike] = (-3, -2, -1), @@ -224,6 +355,10 @@ def FFTND( forward mode, and to :py:func:`scipy.fft.ifftn` (or :py:func:`scipy.fft.irfftn` for real models) in adjoint mode. + When the mkl_fft engine is chosen, the overloads are of :py:func: `mkl_fft._numpy_fft.fftn` + (or :py:func:`mkl_fft._numpy_fft.rfftn` for real models) in forward mode and to :py:func:`mkl_fft._numpy_fft.ifftn` + (or :py:func:`mkl_fft._numpy_fft.irfftn`for real models) in adjoint mode. + When using ``real=True``, the result of the forward is also multiplied by :math:`\sqrt{2}` for all frequency bins except zero and Nyquist along the last ``axes``, and the input of the adjoint is multiplied by @@ -297,7 +432,7 @@ def FFTND( engine : :obj:`str`, optional .. versionadded:: 1.17.0 - Engine used for fft computation (``numpy`` or ``scipy``). + Engine used for fft computation (``numpy`` or ``scipy`` or ``mkl_fft``). dtype : :obj:`str`, optional Type of elements in input array. Note that the ``dtype`` of the operator is the corresponding complex type even when a real type is provided. @@ -350,7 +485,7 @@ def FFTND( the same dimension ``axes``. - If ``norm`` is not one of "ortho", "none", or "1/n". NotImplementedError - If ``engine`` is neither ``numpy``, nor ``scipy``. + If ``engine`` is neither ``numpy``, ``scipy`` nor ``mkl_fft``. Notes ----- @@ -385,7 +520,19 @@ def FFTND( for real input signals. """ - if engine == "numpy": + if engine == "mkl_fft" and mkl_fft_message is None: + f = _FFTND_mklfft( + dims=dims, + axes=axes, + nffts=nffts, + sampling=sampling, + norm=norm, + real=real, + ifftshift_before=ifftshift_before, + fftshift_after=fftshift_after, + dtype=dtype, + ) + elif engine == "numpy" or (engine == "mkl_fft" and mkl_fft_message is not None): f = _FFTND_numpy( dims=dims, axes=axes, diff --git a/pylops/torchoperator.py b/pylops/torchoperator.py index b5a968a5..dfc3f4da 100644 --- a/pylops/torchoperator.py +++ b/pylops/torchoperator.py @@ -10,6 +10,8 @@ from pylops.utils import deps if deps.torch_enabled: + import torch + TensorTypeLike = torch.Tensor from pylops._torchoperator import _TorchOperator else: torch_message = ( @@ -17,7 +19,7 @@ 'the twoway module run "pip install torch" or' '"conda install -c pytorch torch".' ) -from pylops.utils.typing import TensorTypeLike + TensorTypeLike = None class TorchOperator(LinearOperator): diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py index 7fad2838..12a9cc23 100644 --- a/pylops/utils/deps.py +++ b/pylops/utils/deps.py @@ -9,6 +9,7 @@ "spgl1_enabled", "sympy_enabled", "torch_enabled", + "mkl_fft_enabled" ] import os @@ -23,6 +24,7 @@ ) devito_enabled = util.find_spec("devito") is not None numba_enabled = util.find_spec("numba") is not None +mkl_fft_enabled = util.find_spec("mkl_fft") is not None pyfftw_enabled = util.find_spec("pyfftw") is not None pywt_enabled = util.find_spec("pywt") is not None skfmm_enabled = util.find_spec("skfmm") is not None @@ -87,6 +89,25 @@ def pyfftw_import(message): return pyfftw_message +def mkl_fft_import(message): + if mkl_fft_enabled: + try: + from mkl_fft import _scipy_fft_backend # noqa: F401 + from mkl_fft import _numpy_fft # noqa: F401 + mkl_fft_message = None + except Exception as e: + mkl_fft_message = f"Failed to import mkl_fft (error:{e}), use numpy." + else: + mkl_fft_message = ( + "mkl_fft not available, reverting to numpy. " + "In order to be able to use " + f"{message} run " + f'"pip install mkl_fft" or ' + f'"conda install -c conda-forge mkl_fft".' + ) + return mkl_fft_message + + def pywt_import(message): if pywt_enabled: try: diff --git a/pylops/utils/typing.py b/pylops/utils/typing.py index 191efbfa..8945cc5a 100644 --- a/pylops/utils/typing.py +++ b/pylops/utils/typing.py @@ -5,7 +5,6 @@ "SamplingLike", "ShapeLike", "DTypeLike", - "TensorTypeLike", ] from typing import Sequence, Tuple, Union @@ -13,11 +12,6 @@ import numpy as np import numpy.typing as npt -from pylops.utils.deps import torch_enabled - -if torch_enabled: - import torch - IntNDArray = npt.NDArray[np.int_] NDArray = npt.NDArray @@ -25,8 +19,3 @@ SamplingLike = Union[Sequence[float], NDArray] ShapeLike = Tuple[int, ...] DTypeLike = npt.DTypeLike - -if torch_enabled: - TensorTypeLike = torch.Tensor -else: - TensorTypeLike = None diff --git a/pytests/test_ffts.py b/pytests/test_ffts.py index 152f7573..f9083a2c 100755 --- a/pytests/test_ffts.py +++ b/pytests/test_ffts.py @@ -141,6 +141,59 @@ def _choose_random_axes(ndim, n_choices=2): "dtype": np.complex128, } # nfftnt, complex input, fftw engine +par3t = { + "nt": 41, + "nx": 31, + "ny": 10, + "nfft": None, + "real": True, + "engine": "mkl_fft", + "ifftshift_before": False, + "dtype": np.float64, +} # nfft=nt, real input, fftw engine +par4t = { + "nt": 41, + "nx": 31, + "ny": 10, + "nfft": 64, + "real": True, + "engine": "mkl_fft", + "ifftshift_before": False, + "dtype": np.float32, +} # nfft>nt, real input, fftw engine +par5t = { + "nt": 41, + "nx": 31, + "ny": 10, + "nfft": 16, + "real": False, + "engine": "mkl_fft", + "ifftshift_before": False, + "dtype": np.complex128, +} # nfft=1.21.0 scipy>=1.4.0 torch>=1.2.0 diff --git a/requirements-doc.txt b/requirements-doc.txt index d137b621..fd3f56f6 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -1,3 +1,6 @@ +mkl +mkl-fft +mkl-service numpy>=1.21.0 scipy>=1.4.0 torch diff --git a/setup.py b/setup.py index 8b82afa2..0e0e8eed 100755 --- a/setup.py +++ b/setup.py @@ -36,6 +36,9 @@ def src(pth): extras_require={ "advanced": [ "llvmlite", + "mkl", + "mkl_fft", + "mkl-service" "numba", "pyfftw", "PyWavelets",