-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added DTCWT
operator
#495
Added DTCWT
operator
#495
Changes from 6 commits
52e6c12
4b7ec1e
f887031
7435b51
22e6780
60f9ca0
b0e37c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -102,6 +102,7 @@ Signal processing | |
DWT | ||
DWT2D | ||
DCT | ||
DTCWT | ||
Seislet | ||
Radon2D | ||
Radon3D | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
""" | ||
Dual-Tree Complex Wavelet Transform | ||
========================= | ||
This example shows how to use the :py:class:`pylops.signalprocessing.DCT` operator. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be pylops.signalprocessing.DTCWT |
||
This operator performs the 1D Dual-Tree Complex Wavelet Transform on a (single or multi-dimensional) | ||
input array. | ||
""" | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
import pylops | ||
|
||
plt.close("all") | ||
|
||
############################################################################### | ||
# Let's define a 1D array x of having random values | ||
|
||
n = 50 | ||
x = np.random.rand(n,) | ||
|
||
############################################################################### | ||
# We create the DTCWT operator with shape of our input array. DTCWT transform | ||
# gives a Pyramid object that is flattened out `y`. | ||
|
||
DOp = pylops.signalprocessing.DTCWT(dims=x.shape) | ||
y = DOp @ x | ||
xadj = DOp.H @ y | ||
|
||
plt.figure(figsize=(8, 5)) | ||
plt.plot(x, "k", label="input array") | ||
plt.plot(y, "r", label="transformed array") | ||
plt.plot(xadj, "--b", label="transformed array") | ||
plt.title("Dual-Tree Complex Wavelet Transform 1D") | ||
plt.legend() | ||
plt.tight_layout() | ||
|
||
################################################################################# | ||
# To get the Pyramid object use the `get_pyramid` method. | ||
# We can get the Highpass signal and Lowpass signal from it | ||
|
||
pyr = DOp.get_pyramid(y) | ||
|
||
plt.figure(figsize=(10, 5)) | ||
plt.plot(x, "--b", label="orignal signal") | ||
plt.plot(pyr.lowpass, "k", label="lowpass") | ||
plt.plot(pyr.highpasses[0], "r", label="highpass level 1 signal") | ||
plt.plot(pyr.highpasses[1], "b", label="highpass level 2 signal") | ||
plt.plot(pyr.highpasses[2], "g", label="highpass level 3 signal") | ||
|
||
plt.title("DTCWT Pyramid Object") | ||
plt.legend() | ||
plt.tight_layout() | ||
|
||
################################################################################### | ||
# DTCWT can also be performed on multi-dimension arrays. The number of levels can also | ||
# be defined using the `nlevels` | ||
|
||
n = 10 | ||
m = 2 | ||
|
||
x = np.random.rand(n, m) | ||
|
||
DOp = pylops.signalprocessing.DTCWT(dims=x.shape, nlevels=5) | ||
y = DOp @ x | ||
xadj = DOp.H @ y | ||
|
||
plt.figure(figsize=(8, 5)) | ||
plt.plot(x, "k", label="input array") | ||
plt.plot(y, "r", label="transformed array") | ||
plt.plot(xadj, "--b", label="transformed array") | ||
plt.title("Dual-Tree Complex Wavelet Transform 1D on ND array") | ||
plt.legend() | ||
plt.tight_layout() |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,147 @@ | ||||
__all__ = ["DTCWT"] | ||||
|
||||
from typing import Union | ||||
|
||||
import dtcwt | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this needs to be wrapped in a https://github.com/PyLops/pylops/blob/dev/pylops/_torchoperator.py#LL5-LL10 |
||||
import numpy as np | ||||
|
||||
from pylops import LinearOperator | ||||
from pylops.utils._internal import _value_or_sized_to_tuple | ||||
from pylops.utils.decorators import reshaped | ||||
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray | ||||
|
||||
|
||||
class DTCWT(LinearOperator): | ||||
r"""Dual-Tree Complex Wavelet Transform | ||||
Perform 1D Dual-Tree Complex Wavelet Transform on a given array. | ||||
|
||||
This operator wraps around :py:func:`dtcwt` package. | ||||
|
||||
Parameters | ||||
---------- | ||||
dims: :obj:`int` or :obj:`tuple` | ||||
Number of samples for each dimension. | ||||
birot: :obj:`str`, optional | ||||
Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot()`. Default is `"near_sym_a"`. | ||||
qshift: :obj:`str`, optional | ||||
Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift()`. Default is `"qshift_a"` | ||||
nlevels: :obj:`int`, optional | ||||
Number of levels of wavelet decomposition. Default is 3. | ||||
include_scale: :obj:`bool`, optional | ||||
Include scales in pyramid. See :py:func:`dtcwt.Pyramid`. Default is False. | ||||
axis: :obj:`int`, optional | ||||
Axis on which the transform is performed. Default is -1. | ||||
dtype : :obj:`DTypeLike`, optional | ||||
Type of elements in input array. | ||||
name : :obj:`str`, optional | ||||
Name of operator (to be used by :func:`pylops.utils.describe.describe`) | ||||
|
||||
|
||||
Notes | ||||
----- | ||||
The :py:func:`dtcwt` library uses a Pyramid object to represent the transformed domain signal. | ||||
It has | ||||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
- `lowpass` (coarsest scale lowpass signal) | ||||
- `highpasses` (complex subband coefficients for corresponding scales) | ||||
- `scales` (lowpass signal for corresponding scales finest to coarsest) | ||||
|
||||
To make the dtcwt forward() and inverse() functions compatible with pylops, the Pyramid object is | ||||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
flattened out and all coefficents (high-pass and low pass coefficients) are appended into one array using the | ||||
`_coeff_to_array` method. | ||||
For inverse, the flattened array is used to reconstruct the Pyramid object using the `_array_to_coeff` | ||||
method and then inverse is performed. | ||||
""" | ||||
|
||||
def __init__( | ||||
self, | ||||
dims: Union[int, InputDimsLike], | ||||
biort: str = "near_sym_a", | ||||
qshift: str = "qshift_a", | ||||
nlevels: int = 3, | ||||
include_scale: bool = False, | ||||
axis: int = -1, | ||||
dtype: DTypeLike = "float64", | ||||
name: str = "C", | ||||
) -> None: | ||||
self.dims = _value_or_sized_to_tuple(dims) | ||||
self.ndim = len(self.dims) | ||||
self.nlevels = nlevels | ||||
self.include_scale = include_scale | ||||
self.axis = axis | ||||
self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift) | ||||
self._interpret_coeffs() | ||||
super().__init__( | ||||
dtype=np.dtype(dtype), | ||||
dims=self.dims, | ||||
dimsd=(self.coeff_array_size,), | ||||
name=name, | ||||
) | ||||
|
||||
def _interpret_coeffs(self): | ||||
T = np.ones(self.dims) | ||||
T = T.swapaxes(self.axis, -1) | ||||
self.swapped_dims = T.shape | ||||
T = self._nd_to_2d(T) | ||||
pyr = self._transform.forward( | ||||
cako marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
T , nlevels=self.nlevels, include_scale=True | ||||
) | ||||
self.coeff_array_size = 0 | ||||
self.lowpass_size = len(pyr.lowpass) | ||||
self.slices = [] | ||||
for _h in pyr.highpasses: | ||||
self.slices.append(len(_h)) | ||||
self.coeff_array_size += len(_h) | ||||
self.coeff_array_size += self.lowpass_size | ||||
elements = np.prod(T.shape[1:]) | ||||
self.coeff_array_size *= elements | ||||
self.lowpass_size *= elements | ||||
self.first_dim = elements | ||||
|
||||
def _nd_to_2d(self, arr_nd): | ||||
arr_2d = arr_nd.reshape((self.dims[0], -1)) | ||||
return arr_2d | ||||
|
||||
def _2d_to_nd(self, arr_2d): | ||||
arr_nd = arr_2d.reshape(self.swapped_dims) | ||||
return arr_nd | ||||
|
||||
def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray: | ||||
coeffs = pyr.highpasses | ||||
flat_coeffs = [] | ||||
for band in coeffs: | ||||
for c in band: | ||||
flat_coeffs.append(c) | ||||
flat_coeffs = np.concatenate((flat_coeffs, pyr.lowpass)) | ||||
return flat_coeffs | ||||
|
||||
def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid: | ||||
lowpass = np.array([x.real for x in X[-self.lowpass_size :]]).reshape( | ||||
(-1, self.first_dim) | ||||
) | ||||
_ptr = 0 | ||||
highpasses = () | ||||
for _sl in self.slices: | ||||
_h = X[_ptr : _ptr + (_sl * self.first_dim)] | ||||
_ptr += _sl * self.first_dim | ||||
_h = _h.reshape((-1, self.first_dim)) | ||||
highpasses += (_h,) | ||||
return dtcwt.Pyramid(lowpass, highpasses) | ||||
|
||||
def get_pyramid(self, X: NDArray) -> dtcwt.Pyramid: | ||||
"""Return Pyramid object from transformed array | ||||
""" | ||||
return self._array_to_coeff(X) | ||||
|
||||
@reshaped | ||||
def _matvec(self, x: NDArray) -> NDArray: | ||||
x = x.swapaxes(self.axis, -1) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @this is not good, you should use
When something is readily available and used in many other operators we should not deviate from it unless there is a special requirement, I do not see it here :) |
||||
x = self._nd_to_2d(x) | ||||
return self._coeff_to_array( | ||||
self._transform.forward(x, nlevels=self.nlevels, include_scale=False) | ||||
) | ||||
|
||||
@reshaped | ||||
def _rmatvec(self, x: NDArray) -> NDArray: | ||||
Y = self._transform.inverse(self._array_to_coeff(x)) | ||||
Y = self._2d_to_nd(Y) | ||||
return Y.swapaxes(self.axis, -1) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from pylops.signalprocessing import DTCWT | ||
|
||
par1 = {"ny": 10, "nx": 10, "dtype": "float64"} | ||
par2 = {"ny": 50, "nx": 50, "dtype": "float64"} | ||
|
||
|
||
def sequential_array(shape): | ||
num_elements = np.prod(shape) | ||
seq_array = np.arange(1, num_elements + 1) | ||
result = seq_array.reshape(shape) | ||
return result | ||
|
||
|
||
@pytest.mark.parametrize("par", [(par1), (par2)]) | ||
def test_dtcwt1D_input1D(par): | ||
"""Test for DTCWT with 1D input""" | ||
|
||
t = sequential_array((par["ny"],)) | ||
|
||
for levels in range(1, 10): | ||
Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"]) | ||
x = Dtcwt @ t | ||
y = Dtcwt.H @ x | ||
|
||
np.testing.assert_allclose(t, y) | ||
|
||
|
||
@pytest.mark.parametrize("par", [(par1), (par2)]) | ||
def test_dtcwt1D_input2D(par): | ||
"""Test for DTCWT with 2D input""" | ||
|
||
t = sequential_array((par["ny"], par["ny"],)) | ||
|
||
for levels in range(1, 10): | ||
Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"]) | ||
x = Dtcwt @ t | ||
y = Dtcwt.H @ x | ||
|
||
np.testing.assert_allclose(t, y) | ||
|
||
|
||
@pytest.mark.parametrize("par", [(par1), (par2)]) | ||
def test_dtcwt1D_input3D(par): | ||
"""Test for DTCWT with 3D input""" | ||
|
||
t = sequential_array((par["ny"], par["ny"], par["ny"])) | ||
|
||
for levels in range(1, 10): | ||
Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"]) | ||
x = Dtcwt @ t | ||
y = Dtcwt.H @ x | ||
|
||
np.testing.assert_allclose(t, y) | ||
|
||
|
||
@pytest.mark.parametrize("par", [(par1), (par2)]) | ||
def test_dtcwt1D_birot(par): | ||
"""Test for DTCWT birot""" | ||
birots = ["antonini", "legall", "near_sym_a", "near_sym_b"] | ||
|
||
t = sequential_array((par["ny"], par["ny"],)) | ||
|
||
for _b in birots: | ||
print(f"birot {_b}") | ||
Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"]) | ||
x = Dtcwt @ t | ||
y = Dtcwt.H @ x | ||
|
||
np.testing.assert_allclose(t, y) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,3 +26,4 @@ isort | |
black | ||
flake8 | ||
mypy | ||
dtcwt |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,4 +26,5 @@ isort | |
black | ||
flake8 | ||
mypy | ||
pydata-sphinx-theme | ||
pydata-sphinx-theme | ||
dtcwt |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to go all the way to the end of the title