Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: N-dimensional discrete wavelet transforms #583

Merged
merged 4 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ Signal processing
Shift
DWT
DWT2D
DWTND
DCT
DTCWT
Seislet
Expand Down
47 changes: 47 additions & 0 deletions examples/plot_wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,50 @@
axs[1, 1].set_title("DWT2 coefficients (zeroed)")
axs[1, 1].axis("tight")
plt.tight_layout()

###############################################################################
mrava87 marked this conversation as resolved.
Show resolved Hide resolved
# Let us now try the same with a 3D volumetric model, where we use the
# N-dimensional DWT. Again, we only retain a quarter of the coefficients of
# the DWT.

nx = 128
ny = 256
nz = 128

x = np.arange(nx)
y = np.arange(ny)
z = np.arange(nz)

xx, yy, zz = np.meshgrid(x, y, z, indexing="ij")
# Generate a 3D model with two block anomalies
m = np.ones_like(xx, dtype=float)
block1 = (xx > 10) & (xx < 60) & (yy > 100) & (yy < 150) & (zz > 20) & (zz < 70)
block2 = (xx > 70) & (xx < 80) & (yy > 100) & (yy < 200) & (zz > 10) & (zz < 50)
m[block1] = 1.2
m[block2] = 0.8
Wop = pylops.signalprocessing.DWTND((nx, ny, nz), wavelet="haar", level=3)
y = Wop * m

yf = y.copy()
yf.flat[y.size // 4 :] = 0
iminv = Wop.H * yf

ratio = 0.1
mrava87 marked this conversation as resolved.
Show resolved Hide resolved
yf = y.copy()
yf.flat[int(ratio * y.size) :] = 0
iminv = Wop.H * yf

fig, axs = plt.subplots(2, 2, figsize=(6, 6))
axs[0, 0].imshow(m[:, :, 30], cmap="gray")
axs[0, 0].set_title("Model (Slice at z=30)")
axs[0, 0].axis("tight")
axs[0, 1].imshow(y[:, :, 90], cmap="gray_r")
axs[0, 1].set_title("DWTNT coefficients")
axs[0, 1].axis("tight")
axs[1, 0].imshow(iminv[:, :, 30], cmap="gray")
axs[1, 0].set_title("Reconstructed model (Slice at z=30)")
axs[1, 0].axis("tight")
axs[1, 1].imshow(yf[:, :, 90], cmap="gray_r")
axs[1, 1].set_title("DWTNT coefficients (zeroed)")
axs[1, 1].axis("tight")
plt.tight_layout()
3 changes: 3 additions & 0 deletions pylops/signalprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Shift Fractional Shift operator.
DWT One dimensional Wavelet operator.
DWT2D Two dimensional Wavelet operator.
DWTND N-dimensional Wavelet operator.
DCT Discrete Cosine Transform.
DTCWT Dual-Tree Complex Wavelet Transform.
Radon2D Two dimensional Radon transform.
Expand Down Expand Up @@ -61,6 +62,7 @@
from .fredholm1 import *
from .dwt import *
from .dwt2d import *
from .dwtnd import *
from .seislet import *
from .dct import *
from .dtcwt import *
Expand Down Expand Up @@ -93,6 +95,7 @@
"Fredholm1",
"DWT",
"DWT2D",
"DWTND",
"Seislet",
"DCT",
"DTCWT",
Expand Down
142 changes: 142 additions & 0 deletions pylops/signalprocessing/dwtnd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
__all__ = ["DWTND"]

import logging
from math import ceil, log

import numpy as np

from pylops import LinearOperator
from pylops.basicoperators import Pad
from pylops.utils import deps
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

from .dwt import _adjointwavelet, _checkwavelet

pywt_message = deps.pywt_import("the dwtnd module")

if pywt_message is None:
import pywt

logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)


class DWTND(LinearOperator):
"""N-dimensional Wavelet operator.

Apply ND-Wavelet transform along N ``axes`` of a
multi-dimensional array of size ``dims``.

Note that the Wavelet operator is an overload of the ``pywt``
implementation of the wavelet transform. Refer to
https://pywavelets.readthedocs.io for a detailed description of the
input parameters.

Defaults to a 3D wavelet transform along the last three dimensions
of the input array.

Parameters
----------
dims : :obj:`tuple`
Number of samples for each dimension
axes : :obj:`int`, optional
.. versionadded:: 2.0.0
mrava87 marked this conversation as resolved.
Show resolved Hide resolved

Axis along which DWTND is applied
wavelet : :obj:`str`, optional
Name of wavelet type. Use :func:`pywt.wavelist(kind='discrete')` for
a list of available wavelets.
level : :obj:`int`, optional
Number of scaling levels (must be >=0).
dtype : :obj:`str`, optional
Type of elements in input array.
name : :obj:`str`, optional
.. versionadded:: 2.0.0
mrava87 marked this conversation as resolved.
Show resolved Hide resolved

Name of operator (to be used by :func:`pylops.utils.describe.describe`)

Attributes
----------
shape : :obj:`tuple`
Operator shape
explicit : :obj:`bool`
Operator contains a matrix that can be solved explicitly
(``True``) or not (``False``)

Raises
------
ModuleNotFoundError
If ``pywt`` is not installed
ValueError
If ``wavelet`` does not belong to ``pywt.families``

Notes
-----
The Wavelet operator applies the N-dimensional multilevel Discrete
Wavelet Transform (DWTN) in forward mode and the N-dimensional multilevel
Inverse Discrete Wavelet Transform (IDWTN) in adjoint mode.

"""

def __init__(
self,
dims: InputDimsLike,
axes: InputDimsLike = (-3, -2, -1),
wavelet: str = "haar",
level: int = 1,
dtype: DTypeLike = "float64",
name: str = "D",
) -> None:
if pywt_message is not None:
raise ModuleNotFoundError(pywt_message)
_checkwavelet(wavelet)

# define padding for length to be power of 2
ndimpow2 = [max(2 ** ceil(log(dims[ax], 2)), 2**level) for ax in axes]
pad = [(0, 0)] * len(dims)
for i, ax in enumerate(axes):
pad[ax] = (0, ndimpow2[i] - dims[ax])
self.pad = Pad(dims, pad)
self.axes = axes
dimsd = list(dims)
for i, ax in enumerate(axes):
dimsd[ax] = ndimpow2[i]
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)

# apply transform once again to find out slices
_, self.sl = pywt.coeffs_to_array(
pywt.wavedecn(
np.ones(self.dimsd),
wavelet=wavelet,
level=level,
mode="periodization",
axes=self.axes,
),
axes=self.axes,
)
self.wavelet = wavelet
self.waveletadj = _adjointwavelet(wavelet)
self.level = level

def _matvec(self, x: NDArray) -> NDArray:
x = self.pad.matvec(x)
x = np.reshape(x, self.dimsd)
y = pywt.coeffs_to_array(
pywt.wavedecn(
x,
wavelet=self.wavelet,
level=self.level,
mode="periodization",
axes=self.axes,
),
axes=(self.axes),
)[0]
return y.ravel()

def _rmatvec(self, x: NDArray) -> NDArray:
x = np.reshape(x, self.dimsd)
x = pywt.array_to_coeffs(x, self.sl, output_format="wavedecn")
y = pywt.waverecn(
x, wavelet=self.waveletadj, mode="periodization", axes=self.axes
)
y = self.pad.rmatvec(y.ravel())
return y
64 changes: 63 additions & 1 deletion pytests/test_dwts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,20 @@
from numpy.testing import assert_array_almost_equal
from scipy.sparse.linalg import lsqr

from pylops.signalprocessing import DWT, DWT2D
from pylops.signalprocessing import DWT, DWT2D, DWTND
from pylops.utils import dottest

par1 = {"ny": 7, "nx": 9, "nt": 10, "imag": 0, "dtype": "float32"} # real
par2 = {"ny": 7, "nx": 9, "nt": 10, "imag": 1j, "dtype": "complex64"} # complex
par3 = {"ny": 7, "nx": 9, "nz": 9, "nt": 10, "imag": 0, "dtype": "float32"} # real 4D
mrava87 marked this conversation as resolved.
Show resolved Hide resolved
par4 = {
"ny": 7,
"nx": 9,
"nz": 9,
"nt": 10,
"imag": 1j,
"dtype": "complex64",
} # complex 4D

np.random.seed(10)

Expand Down Expand Up @@ -133,3 +142,56 @@ def test_DWT2D_3dsignal(par):

assert_array_almost_equal(x.ravel(), xadj, decimal=8)
assert_array_almost_equal(x.ravel(), xinv, decimal=8)


@pytest.mark.parametrize("par", [(par3), (par4)])
def test_DWTND_3dsignal(par):
"""Dot-test and inversion for DWT2D operator for 3d signal"""
mrava87 marked this conversation as resolved.
Show resolved Hide resolved
DWTop = DWTND(
dims=(par["nt"], par["nx"], par["ny"]), axes=(0, 1, 2), wavelet="haar", level=3
)
x = np.random.normal(0.0, 1.0, (par["nt"], par["nx"], par["ny"])) + par[
"imag"
] * np.random.normal(0.0, 1.0, (par["nt"], par["nx"], par["ny"]))

assert dottest(
DWTop, DWTop.shape[0], DWTop.shape[1], complexflag=0 if par["imag"] == 0 else 3
)

y = DWTop * x.ravel()
xadj = DWTop.H * y # adjoint is same as inverse for dwt
xinv = lsqr(DWTop, y, damp=1e-10, iter_lim=10, atol=1e-8, btol=1e-8, show=0)[0]

assert_array_almost_equal(x.ravel(), xadj, decimal=8)
assert_array_almost_equal(x.ravel(), xinv, decimal=8)


@pytest.mark.parametrize("par", [(par3), (par4)])
def test_DWTND_4dsignal(par):
"""Dot-test and inversion for DWT operator for 4d signal"""
for axes in [(0, 1, 2), (0, 2, 3), (1, 2, 3), (0, 1, 3), (0, 1, 2, 3)]:
DWTop = DWTND(
dims=(par["nt"], par["nx"], par["ny"], par["nz"]),
axes=axes,
wavelet="haar",
level=3,
)
x = np.random.normal(
0.0, 1.0, (par["nt"], par["nx"], par["ny"], par["nz"])
) + par["imag"] * np.random.normal(
0.0, 1.0, (par["nt"], par["nx"], par["ny"], par["nz"])
)

assert dottest(
DWTop,
DWTop.shape[0],
DWTop.shape[1],
complexflag=0 if par["imag"] == 0 else 3,
)

y = DWTop * x.ravel()
xadj = DWTop.H * y # adjoint is same as inverse for dwt
xinv = lsqr(DWTop, y, damp=1e-10, iter_lim=10, atol=1e-8, btol=1e-8, show=0)[0]

assert_array_almost_equal(x.ravel(), xadj, decimal=8)
assert_array_almost_equal(x.ravel(), xinv, decimal=8)