Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added DTCWT operator #495

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ Signal processing
DWT
DWT2D
DCT
DTCWT
Seislet
Radon2D
Radon3D
Expand Down
74 changes: 74 additions & 0 deletions examples/plot_dtcwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
Dual-Tree Complex Wavelet Transform
=========================
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to go all the way to the end of the title

This example shows how to use the :py:class:`pylops.signalprocessing.DCT` operator.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be pylops.signalprocessing.DTCWT

This operator performs the 1D Dual-Tree Complex Wavelet Transform on a (single or multi-dimensional)
input array.
"""

import matplotlib.pyplot as plt
import numpy as np

import pylops

plt.close("all")

###############################################################################
# Let's define a 1D array x of having random values

n = 50
x = np.random.rand(n,)

###############################################################################
# We create the DTCWT operator with shape of our input array. DTCWT transform
# gives a Pyramid object that is flattened out `y`.

DOp = pylops.signalprocessing.DTCWT(dims=x.shape)
y = DOp @ x
xadj = DOp.H @ y

plt.figure(figsize=(8, 5))
plt.plot(x, "k", label="input array")
plt.plot(y, "r", label="transformed array")
plt.plot(xadj, "--b", label="transformed array")
plt.title("Dual-Tree Complex Wavelet Transform 1D")
plt.legend()
plt.tight_layout()

#################################################################################
# To get the Pyramid object use the `get_pyramid` method.
# We can get the Highpass signal and Lowpass signal from it

pyr = DOp.get_pyramid(y)

plt.figure(figsize=(10, 5))
plt.plot(x, "--b", label="orignal signal")
plt.plot(pyr.lowpass, "k", label="lowpass")
plt.plot(pyr.highpasses[0], "r", label="highpass level 1 signal")
plt.plot(pyr.highpasses[1], "b", label="highpass level 2 signal")
plt.plot(pyr.highpasses[2], "g", label="highpass level 3 signal")

plt.title("DTCWT Pyramid Object")
plt.legend()
plt.tight_layout()

###################################################################################
# DTCWT can also be performed on multi-dimension arrays. The number of levels can also
# be defined using the `nlevels`

n = 10
m = 2

x = np.random.rand(n, m)

DOp = pylops.signalprocessing.DTCWT(dims=x.shape, nlevels=5)
y = DOp @ x
xadj = DOp.H @ y

plt.figure(figsize=(8, 5))
plt.plot(x, "k", label="input array")
plt.plot(y, "r", label="transformed array")
plt.plot(xadj, "--b", label="transformed array")
plt.title("Dual-Tree Complex Wavelet Transform 1D on ND array")
plt.legend()
plt.tight_layout()
3 changes: 3 additions & 0 deletions pylops/signalprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DWT One dimensional Wavelet operator.
DWT2D Two dimensional Wavelet operator.
DCT Discrete Cosine Transform.
DTCWT Dual-Tree Complex Wavelet Transforms
Seislet Two dimensional Seislet operator.
Radon2D Two dimensional Radon transform.
Radon3D Three dimensional Radon transform.
Expand Down Expand Up @@ -60,6 +61,7 @@
from .dwt2d import *
from .seislet import *
from .dct import *
from .dtcwt import *

__all__ = [
"FFT",
Expand Down Expand Up @@ -89,4 +91,5 @@
"DWT2D",
"Seislet",
"DCT",
"DTCWT",
]
147 changes: 147 additions & 0 deletions pylops/signalprocessing/dtcwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
__all__ = ["DTCWT"]

from typing import Union

import dtcwt
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to be wrapped in a pylops.utils.deps check like here:

https://github.com/PyLops/pylops/blob/dev/pylops/_torchoperator.py#LL5-LL10

import numpy as np

from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray


class DTCWT(LinearOperator):
r"""Dual-Tree Complex Wavelet Transform
Perform 1D Dual-Tree Complex Wavelet Transform on a given array.

This operator wraps around :py:func:`dtcwt` package.

Parameters
----------
dims: :obj:`int` or :obj:`tuple`
Number of samples for each dimension.
birot: :obj:`str`, optional
Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot()`. Default is `"near_sym_a"`.
qshift: :obj:`str`, optional
Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift()`. Default is `"qshift_a"`
nlevels: :obj:`int`, optional
Number of levels of wavelet decomposition. Default is 3.
include_scale: :obj:`bool`, optional
Include scales in pyramid. See :py:func:`dtcwt.Pyramid`. Default is False.
axis: :obj:`int`, optional
Axis on which the transform is performed. Default is -1.
dtype : :obj:`DTypeLike`, optional
Type of elements in input array.
name : :obj:`str`, optional
Name of operator (to be used by :func:`pylops.utils.describe.describe`)


Notes
-----
The :py:func:`dtcwt` library uses a Pyramid object to represent the transformed domain signal.
It has
cako marked this conversation as resolved.
Show resolved Hide resolved
- `lowpass` (coarsest scale lowpass signal)
- `highpasses` (complex subband coefficients for corresponding scales)
- `scales` (lowpass signal for corresponding scales finest to coarsest)

To make the dtcwt forward() and inverse() functions compatible with pylops, the Pyramid object is
cako marked this conversation as resolved.
Show resolved Hide resolved
flattened out and all coefficents (high-pass and low pass coefficients) are appended into one array using the
`_coeff_to_array` method.
For inverse, the flattened array is used to reconstruct the Pyramid object using the `_array_to_coeff`
method and then inverse is performed.
"""

def __init__(
self,
dims: Union[int, InputDimsLike],
biort: str = "near_sym_a",
qshift: str = "qshift_a",
nlevels: int = 3,
include_scale: bool = False,
axis: int = -1,
dtype: DTypeLike = "float64",
name: str = "C",
) -> None:
self.dims = _value_or_sized_to_tuple(dims)
self.ndim = len(self.dims)
self.nlevels = nlevels
self.include_scale = include_scale
self.axis = axis
self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift)
self._interpret_coeffs()
super().__init__(
dtype=np.dtype(dtype),
dims=self.dims,
dimsd=(self.coeff_array_size,),
name=name,
)

def _interpret_coeffs(self):
T = np.ones(self.dims)
T = T.swapaxes(self.axis, -1)
self.swapped_dims = T.shape
T = self._nd_to_2d(T)
pyr = self._transform.forward(
cako marked this conversation as resolved.
Show resolved Hide resolved
T , nlevels=self.nlevels, include_scale=True
)
self.coeff_array_size = 0
self.lowpass_size = len(pyr.lowpass)
self.slices = []
for _h in pyr.highpasses:
self.slices.append(len(_h))
self.coeff_array_size += len(_h)
self.coeff_array_size += self.lowpass_size
elements = np.prod(T.shape[1:])
self.coeff_array_size *= elements
self.lowpass_size *= elements
self.first_dim = elements

def _nd_to_2d(self, arr_nd):
arr_2d = arr_nd.reshape((self.dims[0], -1))
return arr_2d

def _2d_to_nd(self, arr_2d):
arr_nd = arr_2d.reshape(self.swapped_dims)
return arr_nd

def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray:
coeffs = pyr.highpasses
flat_coeffs = []
for band in coeffs:
for c in band:
flat_coeffs.append(c)
flat_coeffs = np.concatenate((flat_coeffs, pyr.lowpass))
return flat_coeffs

def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid:
lowpass = np.array([x.real for x in X[-self.lowpass_size :]]).reshape(
(-1, self.first_dim)
)
_ptr = 0
highpasses = ()
for _sl in self.slices:
_h = X[_ptr : _ptr + (_sl * self.first_dim)]
_ptr += _sl * self.first_dim
_h = _h.reshape((-1, self.first_dim))
highpasses += (_h,)
return dtcwt.Pyramid(lowpass, highpasses)

def get_pyramid(self, X: NDArray) -> dtcwt.Pyramid:
"""Return Pyramid object from transformed array
"""
return self._array_to_coeff(X)

@reshaped
def _matvec(self, x: NDArray) -> NDArray:
x = x.swapaxes(self.axis, -1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@this is not good, you should use swapaxes like here

@reshaped(swapaxis=True)

When something is readily available and used in many other operators we should not deviate from it unless there is a special requirement, I do not see it here :)

x = self._nd_to_2d(x)
return self._coeff_to_array(
self._transform.forward(x, nlevels=self.nlevels, include_scale=False)
)

@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
Y = self._transform.inverse(self._array_to_coeff(x))
Y = self._2d_to_nd(Y)
return Y.swapaxes(self.axis, -1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same

72 changes: 72 additions & 0 deletions pytests/test_dtcwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np
import pytest

from pylops.signalprocessing import DTCWT

par1 = {"ny": 10, "nx": 10, "dtype": "float64"}
par2 = {"ny": 50, "nx": 50, "dtype": "float64"}


def sequential_array(shape):
num_elements = np.prod(shape)
seq_array = np.arange(1, num_elements + 1)
result = seq_array.reshape(shape)
return result


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_input1D(par):
"""Test for DTCWT with 1D input"""

t = sequential_array((par["ny"],))

for levels in range(1, 10):
Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x

np.testing.assert_allclose(t, y)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_input2D(par):
"""Test for DTCWT with 2D input"""

t = sequential_array((par["ny"], par["ny"],))

for levels in range(1, 10):
Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x

np.testing.assert_allclose(t, y)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_input3D(par):
"""Test for DTCWT with 3D input"""

t = sequential_array((par["ny"], par["ny"], par["ny"]))

for levels in range(1, 10):
Dtcwt = DTCWT(dims=t.shape, nlevels=levels, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x

np.testing.assert_allclose(t, y)


@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_birot(par):
"""Test for DTCWT birot"""
birots = ["antonini", "legall", "near_sym_a", "near_sym_b"]

t = sequential_array((par["ny"], par["ny"],))

for _b in birots:
print(f"birot {_b}")
Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"])
x = Dtcwt @ t
y = Dtcwt.H @ x

np.testing.assert_allclose(t, y)
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ isort
black
flake8
mypy
dtcwt
3 changes: 2 additions & 1 deletion requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ isort
black
flake8
mypy
pydata-sphinx-theme
pydata-sphinx-theme
dtcwt