Skip to content

Commit

Permalink
feat: added DTCWT 1d operator
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Feb 29, 2024
1 parent e588fe3 commit b56ab83
Show file tree
Hide file tree
Showing 11 changed files with 387 additions and 3 deletions.
13 changes: 12 additions & 1 deletion docs/source/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,20 @@ of GPUs should install it prior to installing PyLops as described in :ref:`Optio
In alphabetic order:


dtcwt
-----
`dtcwt <https://dtcwt.readthedocs.io/en/0.12.0/>`_ is a library used to implement the DT-CWT operators.

Install it via ``pip`` with:

.. code-block:: bash
>> pip install dtcwt
Devito
------
`Devito <https://github.com/devitocodes/devito>`_ is library used to solve PDEs via
`Devito <https://github.com/devitocodes/devito>`_ is a library used to solve PDEs via
the finite-difference method. It is used in PyLops to compute wavefields
:py:class:`pylops.waveeqprocessing.AcousticWave2D`

Expand Down
1 change: 1 addition & 0 deletions environment-dev-arm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies:
- black
- pip:
- devito
- dtcwt
- scikit-fmm
- spgl1
- pytest-runner
Expand Down
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies:
- black
- pip:
- devito
- dtcwt
- scikit-fmm
- spgl1
- pytest-runner
Expand Down
82 changes: 82 additions & 0 deletions examples/plot_dtcwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
Dual-Tree Complex Wavelet Transform
===================================
This example shows how to use the :py:class:`pylops.signalprocessing.DTCWT` operator to perform the
1D Dual-Tree Complex Wavelet Transform on a (single or multi-dimensional) input array. Such a transform
provides advantages over the DWT which lacks shift invariance in 1-D and directional sensitivity in N-D.
"""

import matplotlib.pyplot as plt
import numpy as np
import pywt

import pylops

plt.close("all")

###############################################################################
# To begin with, let's define two 1D arrays with a spike at slightly different location

n = 128
x = np.zeros(n)
x1 = np.zeros(n)

x[59] = 1
x1[63] = 1

###############################################################################
# We now create the DTCWT operator with the shape of our input array. The DTCWT transform
# provides a Pyramid object that is internally flattened out into a vector. Here we re-obtain
# the Pyramid object such that we can visualize the different scales indipendently.

level = 3
DCOp = pylops.signalprocessing.DTCWT(dims=n, level=level)
Xc = DCOp.get_pyramid(DCOp @ x)
Xc1 = DCOp.get_pyramid(DCOp @ x1)

###############################################################################
# To prove the superiority of the DTCWT transform over the DWT in shift-invariance,
# let's also compute the DWT transform of these two signals and compare the coefficents
# of both transform at level 3. As you will see, the coefficients change completely for
# the DWT despite the two input signals are very similar; this is not the case for the
# DCWT transform.

DOp = pylops.signalprocessing.DWT(dims=n, level=level, wavelet="sym7")
X = pywt.array_to_coeffs(DOp @ x, DOp.sl, output_format="wavedecn")
X1 = pywt.array_to_coeffs(DOp @ x1, DOp.sl, output_format="wavedecn")

fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 5))
axs[0, 0].stem(np.abs(X[1]["d"]), linefmt="k", markerfmt=".k", basefmt="k")
axs[0, 0].set_title(f"DWT (Norm={np.linalg.norm(np.abs(X[1]['d']))**2:.3f})")
axs[0, 1].stem(np.abs(X1[1]["d"]), linefmt="k", markerfmt=".k", basefmt="k")
axs[0, 1].set_title(f"DWT (Norm={np.linalg.norm(np.abs(X1[1]['d']))**2:.3f})")
axs[1, 0].stem(np.abs(Xc.highpasses[2]), linefmt="k", markerfmt=".k", basefmt="k")
axs[1, 0].set_title(f"DCWT (Norm={np.linalg.norm(np.abs(Xc.highpasses[2]))**2:.3f})")
axs[1, 1].stem(np.abs(Xc1.highpasses[2]), linefmt="k", markerfmt=".k", basefmt="k")
axs[1, 1].set_title(f"DCWT (Norm={np.linalg.norm(np.abs(Xc1.highpasses[2]))**2:.3f})")
plt.tight_layout()

###################################################################################
# The DTCWT can also be performed on multi-dimension arrays, where the parameter
# ``axis`` is used to define the axis over which the transform is performed. Let's
# just replicate our input signal over the second axis and see how the transform
# will produce the same series of coefficients for all replicas.

nrepeat = 10
x = np.repeat(np.random.rand(n, 1), 10, axis=1).T

level = 3
DCOp = pylops.signalprocessing.DTCWT(dims=(nrepeat, n), level=level, axis=1)
X = DCOp @ x

fig, axs = plt.subplots(1, 2, sharey=True, figsize=(10, 3))
axs[0].imshow(X[0])
axs[0].axis("tight")
axs[0].set_xlabel("Coeffs")
axs[0].set_ylabel("Replicas")
axs[0].set_title("DTCWT Real")
axs[1].imshow(X[1])
axs[1].axis("tight")
axs[1].set_xlabel("Coeffs")
axs[1].set_title("DTCWT Imag")
plt.tight_layout()
6 changes: 5 additions & 1 deletion pylops/signalprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
DWT One dimensional Wavelet operator.
DWT2D Two dimensional Wavelet operator.
DCT Discrete Cosine Transform.
Seislet Two dimensional Seislet operator.
DTCWT Dual-Tree Complex Wavelet Transform.
Radon2D Two dimensional Radon transform.
Radon3D Three dimensional Radon transform.
Seislet Two dimensional Seislet operator.
Sliding1D 1D Sliding transform operator.
Sliding2D 2D Sliding transform operator.
Sliding3D 3D Sliding transform operator.
Expand Down Expand Up @@ -62,6 +63,8 @@
from .dwt2d import *
from .seislet import *
from .dct import *
from .dtcwt import *


__all__ = [
"FFT",
Expand Down Expand Up @@ -92,4 +95,5 @@
"DWT2D",
"Seislet",
"DCT",
"DTCWT",
]
2 changes: 1 addition & 1 deletion pylops/signalprocessing/dct.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class DCT(LinearOperator):
axes : :obj:`int` or :obj:`list`, optional
Axes over which the DCT is computed. If ``None``, the transform is applied
over all axes.
workers :obj:`int`, optional
workers : :obj:`int`, optional
Maximum number of workers to use for parallel computation. If negative,
the value wraps around from os.cpu_count().
dtype : :obj:`DTypeLike`, optional
Expand Down
182 changes: 182 additions & 0 deletions pylops/signalprocessing/dtcwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
__all__ = ["DTCWT"]

from typing import Union

import numpy as np

from pylops import LinearOperator
from pylops.utils import deps
from pylops.utils._internal import _value_or_sized_to_tuple
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

dtcwt_message = deps.dtcwt_import("the dtcwt module")

if dtcwt_message is None:
import dtcwt


class DTCWT(LinearOperator):
r"""Dual-Tree Complex Wavelet Transform
Perform 1D Dual-Tree Complex Wavelet Transform along an ``axis`` of a
multi-dimensional array of size ``dims``.
Note that the DTCWT operator is an overload of the ``dtcwt``
implementation of the DT-CWT transform. Refer to
https://dtcwt.readthedocs.io for a detailed description of the
input parameters.
Parameters
----------
dims : :obj:`int` or :obj:`tuple`
Number of samples for each dimension.
birot : :obj:`str`, optional
Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot`. Default is `"near_sym_a"`.
qshift : :obj:`str`, optional
Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift`. Default is `"qshift_a"`
level : :obj:`int`, optional
Number of levels of wavelet decomposition. Default is 3.
include_scale : :obj:`bool`, optional
Include scales in pyramid. See :py:class:`dtcwt.Pyramid`. Default is False.
axis : :obj:`int`, optional
Axis on which the transform is performed.
dtype : :obj:`DTypeLike`, optional
Type of elements in input array.
name : :obj:`str`, optional
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
Notes
-----
The DTCWT operator applies the dual-tree complex wavelet transform
in forward mode and the dual-tree complex inverse wavelet transform in adjoint mode
from the ``dtcwt`` library.
The ``dtcwt`` library uses a Pyramid object to represent the signal in the transformed domain,
which is composed of:
- `lowpass` (coarsest scale lowpass signal);
- `highpasses` (complex subband coefficients for corresponding scales);
- `scales` (lowpass signal for corresponding scales finest to coarsest).
To make the dtcwt forward() and inverse() functions compatible with PyLops, in forward model
the Pyramid object is flattened out and all coefficients (high-pass and low pass coefficients)
are appended into one array using the `_coeff_to_array` method.
In adjoint mode, the input array is transformed back into a Pyramid object using the `_array_to_coeff`
method and then the inverse transform is performed.
"""

def __init__(
self,
dims: Union[int, InputDimsLike],
biort: str = "near_sym_a",
qshift: str = "qshift_a",
level: int = 3,
include_scale: bool = False,
axis: int = -1,
dtype: DTypeLike = "float64",
name: str = "C",
) -> None:
if dtcwt_message is not None:
raise NotImplementedError(dtcwt_message)

dims = _value_or_sized_to_tuple(dims)
self.ndim = len(dims)
self.axis = axis

self.otherdims = int(np.prod(dims) / dims[self.axis])
self.dims_swapped = list(dims)
self.dims_swapped[0], self.dims_swapped[self.axis] = (
self.dims_swapped[self.axis],
self.dims_swapped[0],
)
self.dims_swapped = tuple(self.dims_swapped)
self.level = level
self.include_scale = include_scale

# dry-run of transform to find dimensions of coefficients at different levels
self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift)
self._interpret_coeffs(dims, self.axis)

dimsd = list(dims)
dimsd[self.axis] = self.coeff_array_size
self.dimsd_swapped = list(dimsd)
self.dimsd_swapped[0], self.dimsd_swapped[self.axis] = (
self.dimsd_swapped[self.axis],
self.dimsd_swapped[0],
)
self.dimsd_swapped = tuple(self.dimsd_swapped)
dimsd = tuple(
[
2,
]
+ dimsd
)

super().__init__(
dtype=np.dtype(dtype),
clinear=False,
dims=dims,
dimsd=dimsd,
name=name,
)

def _interpret_coeffs(self, dims, axis):
x = np.ones(dims[axis])
pyr = self._transform.forward(
x, nlevels=self.level, include_scale=self.include_scale
)
self.lowpass_size = pyr.lowpass.size
self.coeff_array_size = self.lowpass_size
self.highpass_sizes = []
for _h in pyr.highpasses:
self.highpass_sizes.append(_h.size)
self.coeff_array_size += _h.size

def _nd_to_2d(self, arr_nd):
arr_2d = arr_nd.reshape(self.dims[self.axis], -1).squeeze()
return arr_2d

def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray:
highpass_coeffs = np.vstack([h for h in pyr.highpasses])
coeffs = np.concatenate((highpass_coeffs, pyr.lowpass), axis=0)
return coeffs

def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid:
lowpass = (X[-self.lowpass_size :].real).reshape((-1, self.otherdims))
_ptr = 0
highpasses = ()
for _sl in self.highpass_sizes:
_h = X[_ptr : _ptr + _sl]
_ptr += _sl
_h = _h.reshape(-1, self.otherdims)
highpasses += (_h,)
return dtcwt.Pyramid(lowpass, highpasses)

def get_pyramid(self, x: NDArray) -> dtcwt.Pyramid:
"""Return Pyramid object from flat real-valued array"""
return self._array_to_coeff(x[0] + 1j * x[1])

@reshaped
def _matvec(self, x: NDArray) -> NDArray:
x = x.swapaxes(self.axis, 0)
y = self._nd_to_2d(x)
y = self._coeff_to_array(
self._transform.forward(
y, nlevels=self.level, include_scale=self.include_scale
)
)
y = y.reshape(self.dimsd_swapped)
y = y.swapaxes(self.axis, 0)
y = np.concatenate([y.real[np.newaxis], y.imag[np.newaxis]])
return y

@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
x = x[0] + 1j * x[1]
x = x.swapaxes(self.axis, 0)
y = self._transform.inverse(self._array_to_coeff(x))
y = y.reshape(self.dims_swapped)
y = y.swapaxes(self.axis, 0)
return y
19 changes: 19 additions & 0 deletions pylops/utils/deps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = [
"cupy_enabled",
"devito_enabled",
"dtcwt_enabled",
"numba_enabled",
"pyfftw_enabled",
"pywt_enabled",
Expand Down Expand Up @@ -67,6 +68,23 @@ def devito_import(message: Optional[str] = None) -> str:
return devito_message


def dtcwt_import(message: Optional[str] = None) -> str:
if dtcwt_enabled:
try:
import dtcwt # noqa: F401

dtcwt_message = None
except Exception as e:
dtcwt_message = f"Failed to import dtcwt (error:{e})."
else:
dtcwt_message = (
f"Dtcwt not available. "
f"In order to be able to use "
f'{message} run "pip install dtcwt".'
)
return dtcwt_message


def numba_import(message: Optional[str] = None) -> str:
if numba_enabled:
try:
Expand Down Expand Up @@ -187,6 +205,7 @@ def sympy_import(message: Optional[str] = None) -> str:
True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False
)
devito_enabled = util.find_spec("devito") is not None
dtcwt_enabled = util.find_spec("dtcwt") is not None
numba_enabled = util.find_spec("numba") is not None
pyfftw_enabled = util.find_spec("pyfftw") is not None
pywt_enabled = util.find_spec("pywt") is not None
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ advanced = [
"PyWavelets",
"scikit-fmm",
"spgl1",
"dtcwt",
]

[tool.setuptools.packages.find]
Expand Down
Loading

0 comments on commit b56ab83

Please sign in to comment.