-
Notifications
You must be signed in to change notification settings - Fork 107
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #569 from mrava87/feature-dtcwtclean
feat: added DTCWT 1d operator
- Loading branch information
Showing
11 changed files
with
387 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ dependencies: | |
- black | ||
- pip: | ||
- devito | ||
- dtcwt | ||
- scikit-fmm | ||
- spgl1 | ||
- pytest-runner | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ dependencies: | |
- black | ||
- pip: | ||
- devito | ||
- dtcwt | ||
- scikit-fmm | ||
- spgl1 | ||
- pytest-runner | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
""" | ||
Dual-Tree Complex Wavelet Transform | ||
=================================== | ||
This example shows how to use the :py:class:`pylops.signalprocessing.DTCWT` operator to perform the | ||
1D Dual-Tree Complex Wavelet Transform on a (single or multi-dimensional) input array. Such a transform | ||
provides advantages over the DWT which lacks shift invariance in 1-D and directional sensitivity in N-D. | ||
""" | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pywt | ||
|
||
import pylops | ||
|
||
plt.close("all") | ||
|
||
############################################################################### | ||
# To begin with, let's define two 1D arrays with a spike at slightly different location | ||
|
||
n = 128 | ||
x = np.zeros(n) | ||
x1 = np.zeros(n) | ||
|
||
x[59] = 1 | ||
x1[63] = 1 | ||
|
||
############################################################################### | ||
# We now create the DTCWT operator with the shape of our input array. The DTCWT transform | ||
# provides a Pyramid object that is internally flattened out into a vector. Here we re-obtain | ||
# the Pyramid object such that we can visualize the different scales indipendently. | ||
|
||
level = 3 | ||
DCOp = pylops.signalprocessing.DTCWT(dims=n, level=level) | ||
Xc = DCOp.get_pyramid(DCOp @ x) | ||
Xc1 = DCOp.get_pyramid(DCOp @ x1) | ||
|
||
############################################################################### | ||
# To prove the superiority of the DTCWT transform over the DWT in shift-invariance, | ||
# let's also compute the DWT transform of these two signals and compare the coefficents | ||
# of both transform at level 3. As you will see, the coefficients change completely for | ||
# the DWT despite the two input signals are very similar; this is not the case for the | ||
# DCWT transform. | ||
|
||
DOp = pylops.signalprocessing.DWT(dims=n, level=level, wavelet="sym7") | ||
X = pywt.array_to_coeffs(DOp @ x, DOp.sl, output_format="wavedecn") | ||
X1 = pywt.array_to_coeffs(DOp @ x1, DOp.sl, output_format="wavedecn") | ||
|
||
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 5)) | ||
axs[0, 0].stem(np.abs(X[1]["d"]), linefmt="k", markerfmt=".k", basefmt="k") | ||
axs[0, 0].set_title(f"DWT (Norm={np.linalg.norm(np.abs(X[1]['d']))**2:.3f})") | ||
axs[0, 1].stem(np.abs(X1[1]["d"]), linefmt="k", markerfmt=".k", basefmt="k") | ||
axs[0, 1].set_title(f"DWT (Norm={np.linalg.norm(np.abs(X1[1]['d']))**2:.3f})") | ||
axs[1, 0].stem(np.abs(Xc.highpasses[2]), linefmt="k", markerfmt=".k", basefmt="k") | ||
axs[1, 0].set_title(f"DCWT (Norm={np.linalg.norm(np.abs(Xc.highpasses[2]))**2:.3f})") | ||
axs[1, 1].stem(np.abs(Xc1.highpasses[2]), linefmt="k", markerfmt=".k", basefmt="k") | ||
axs[1, 1].set_title(f"DCWT (Norm={np.linalg.norm(np.abs(Xc1.highpasses[2]))**2:.3f})") | ||
plt.tight_layout() | ||
|
||
################################################################################### | ||
# The DTCWT can also be performed on multi-dimension arrays, where the parameter | ||
# ``axis`` is used to define the axis over which the transform is performed. Let's | ||
# just replicate our input signal over the second axis and see how the transform | ||
# will produce the same series of coefficients for all replicas. | ||
|
||
nrepeat = 10 | ||
x = np.repeat(np.random.rand(n, 1), 10, axis=1).T | ||
|
||
level = 3 | ||
DCOp = pylops.signalprocessing.DTCWT(dims=(nrepeat, n), level=level, axis=1) | ||
X = DCOp @ x | ||
|
||
fig, axs = plt.subplots(1, 2, sharey=True, figsize=(10, 3)) | ||
axs[0].imshow(X[0]) | ||
axs[0].axis("tight") | ||
axs[0].set_xlabel("Coeffs") | ||
axs[0].set_ylabel("Replicas") | ||
axs[0].set_title("DTCWT Real") | ||
axs[1].imshow(X[1]) | ||
axs[1].axis("tight") | ||
axs[1].set_xlabel("Coeffs") | ||
axs[1].set_title("DTCWT Imag") | ||
plt.tight_layout() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
__all__ = ["DTCWT"] | ||
|
||
from typing import Union | ||
|
||
import numpy as np | ||
|
||
from pylops import LinearOperator | ||
from pylops.utils import deps | ||
from pylops.utils._internal import _value_or_sized_to_tuple | ||
from pylops.utils.decorators import reshaped | ||
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray | ||
|
||
dtcwt_message = deps.dtcwt_import("the dtcwt module") | ||
|
||
if dtcwt_message is None: | ||
import dtcwt | ||
|
||
|
||
class DTCWT(LinearOperator): | ||
r"""Dual-Tree Complex Wavelet Transform | ||
Perform 1D Dual-Tree Complex Wavelet Transform along an ``axis`` of a | ||
multi-dimensional array of size ``dims``. | ||
Note that the DTCWT operator is an overload of the ``dtcwt`` | ||
implementation of the DT-CWT transform. Refer to | ||
https://dtcwt.readthedocs.io for a detailed description of the | ||
input parameters. | ||
Parameters | ||
---------- | ||
dims : :obj:`int` or :obj:`tuple` | ||
Number of samples for each dimension. | ||
birot : :obj:`str`, optional | ||
Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot`. Default is `"near_sym_a"`. | ||
qshift : :obj:`str`, optional | ||
Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift`. Default is `"qshift_a"` | ||
level : :obj:`int`, optional | ||
Number of levels of wavelet decomposition. Default is 3. | ||
include_scale : :obj:`bool`, optional | ||
Include scales in pyramid. See :py:class:`dtcwt.Pyramid`. Default is False. | ||
axis : :obj:`int`, optional | ||
Axis on which the transform is performed. | ||
dtype : :obj:`DTypeLike`, optional | ||
Type of elements in input array. | ||
name : :obj:`str`, optional | ||
Name of operator (to be used by :func:`pylops.utils.describe.describe`) | ||
Notes | ||
----- | ||
The DTCWT operator applies the dual-tree complex wavelet transform | ||
in forward mode and the dual-tree complex inverse wavelet transform in adjoint mode | ||
from the ``dtcwt`` library. | ||
The ``dtcwt`` library uses a Pyramid object to represent the signal in the transformed domain, | ||
which is composed of: | ||
- `lowpass` (coarsest scale lowpass signal); | ||
- `highpasses` (complex subband coefficients for corresponding scales); | ||
- `scales` (lowpass signal for corresponding scales finest to coarsest). | ||
To make the dtcwt forward() and inverse() functions compatible with PyLops, in forward model | ||
the Pyramid object is flattened out and all coefficients (high-pass and low pass coefficients) | ||
are appended into one array using the `_coeff_to_array` method. | ||
In adjoint mode, the input array is transformed back into a Pyramid object using the `_array_to_coeff` | ||
method and then the inverse transform is performed. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
dims: Union[int, InputDimsLike], | ||
biort: str = "near_sym_a", | ||
qshift: str = "qshift_a", | ||
level: int = 3, | ||
include_scale: bool = False, | ||
axis: int = -1, | ||
dtype: DTypeLike = "float64", | ||
name: str = "C", | ||
) -> None: | ||
if dtcwt_message is not None: | ||
raise NotImplementedError(dtcwt_message) | ||
|
||
dims = _value_or_sized_to_tuple(dims) | ||
self.ndim = len(dims) | ||
self.axis = axis | ||
|
||
self.otherdims = int(np.prod(dims) / dims[self.axis]) | ||
self.dims_swapped = list(dims) | ||
self.dims_swapped[0], self.dims_swapped[self.axis] = ( | ||
self.dims_swapped[self.axis], | ||
self.dims_swapped[0], | ||
) | ||
self.dims_swapped = tuple(self.dims_swapped) | ||
self.level = level | ||
self.include_scale = include_scale | ||
|
||
# dry-run of transform to find dimensions of coefficients at different levels | ||
self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift) | ||
self._interpret_coeffs(dims, self.axis) | ||
|
||
dimsd = list(dims) | ||
dimsd[self.axis] = self.coeff_array_size | ||
self.dimsd_swapped = list(dimsd) | ||
self.dimsd_swapped[0], self.dimsd_swapped[self.axis] = ( | ||
self.dimsd_swapped[self.axis], | ||
self.dimsd_swapped[0], | ||
) | ||
self.dimsd_swapped = tuple(self.dimsd_swapped) | ||
dimsd = tuple( | ||
[ | ||
2, | ||
] | ||
+ dimsd | ||
) | ||
|
||
super().__init__( | ||
dtype=np.dtype(dtype), | ||
clinear=False, | ||
dims=dims, | ||
dimsd=dimsd, | ||
name=name, | ||
) | ||
|
||
def _interpret_coeffs(self, dims, axis): | ||
x = np.ones(dims[axis]) | ||
pyr = self._transform.forward( | ||
x, nlevels=self.level, include_scale=self.include_scale | ||
) | ||
self.lowpass_size = pyr.lowpass.size | ||
self.coeff_array_size = self.lowpass_size | ||
self.highpass_sizes = [] | ||
for _h in pyr.highpasses: | ||
self.highpass_sizes.append(_h.size) | ||
self.coeff_array_size += _h.size | ||
|
||
def _nd_to_2d(self, arr_nd): | ||
arr_2d = arr_nd.reshape(self.dims[self.axis], -1).squeeze() | ||
return arr_2d | ||
|
||
def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray: | ||
highpass_coeffs = np.vstack([h for h in pyr.highpasses]) | ||
coeffs = np.concatenate((highpass_coeffs, pyr.lowpass), axis=0) | ||
return coeffs | ||
|
||
def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid: | ||
lowpass = (X[-self.lowpass_size :].real).reshape((-1, self.otherdims)) | ||
_ptr = 0 | ||
highpasses = () | ||
for _sl in self.highpass_sizes: | ||
_h = X[_ptr : _ptr + _sl] | ||
_ptr += _sl | ||
_h = _h.reshape(-1, self.otherdims) | ||
highpasses += (_h,) | ||
return dtcwt.Pyramid(lowpass, highpasses) | ||
|
||
def get_pyramid(self, x: NDArray) -> dtcwt.Pyramid: | ||
"""Return Pyramid object from flat real-valued array""" | ||
return self._array_to_coeff(x[0] + 1j * x[1]) | ||
|
||
@reshaped | ||
def _matvec(self, x: NDArray) -> NDArray: | ||
x = x.swapaxes(self.axis, 0) | ||
y = self._nd_to_2d(x) | ||
y = self._coeff_to_array( | ||
self._transform.forward( | ||
y, nlevels=self.level, include_scale=self.include_scale | ||
) | ||
) | ||
y = y.reshape(self.dimsd_swapped) | ||
y = y.swapaxes(self.axis, 0) | ||
y = np.concatenate([y.real[np.newaxis], y.imag[np.newaxis]]) | ||
return y | ||
|
||
@reshaped | ||
def _rmatvec(self, x: NDArray) -> NDArray: | ||
x = x[0] + 1j * x[1] | ||
x = x.swapaxes(self.axis, 0) | ||
y = self._transform.inverse(self._array_to_coeff(x)) | ||
y = y.reshape(self.dims_swapped) | ||
y = y.swapaxes(self.axis, 0) | ||
return y |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,7 @@ advanced = [ | |
"PyWavelets", | ||
"scikit-fmm", | ||
"spgl1", | ||
"dtcwt", | ||
] | ||
|
||
[tool.setuptools.packages.find] | ||
|
Oops, something went wrong.