Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UDCT operator #610

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions environment-dev-arm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies:
- pip:
- devito
- dtcwt
- ucurv
- scikit-fmm
- spgl1
- pytest-runner
Expand All @@ -39,3 +40,4 @@ dependencies:
- image
- flake8
- mypy

1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies:
- pip:
- devito
- dtcwt
- ucurv
- scikit-fmm
- spgl1
- pytest-runner
Expand Down
48 changes: 48 additions & 0 deletions examples/plot_udct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Uniform Discrete Curvelet Transform
===================================
This example shows how to use the :py:class:`pylops.signalprocessing.UDCT` operator to perform the
Uniform Discrete Curvelet Transform on a (multi-dimensional) input array.
"""

import numpy as np
from ucurv import udct, zoneplate, ucurvfwd, ucurvinv, ucurv2d_show
import matplotlib.pyplot as plt
import pylops
plt.close("all")

sz = [512, 512]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should try to write the entire example using the pylops.signalprocessing.UDCT methods (matvec/rmatvec) instead of those in the ucurv library. Of course, you can for example use the ucurv2d_show method but when you compute the forward and backward of your transform you should do this using the pylops operator.

cfg = [[3, 3], [6, 6]]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try to add some additional text explaining what you are doing, right now it is a bit hard to follow the example

res = len(cfg)
rsq = zoneplate(sz)
img = rsq - np.mean(rsq)

transform = udct(sz, cfg, complex=False, high="curvelet")

imband = ucurvfwd(img, transform)
plt.figure(figsize=(20, 60))
print(imband.keys())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this print meant to be there, maybe not whilst you make the plot?

plt.imshow(np.abs(ucurv2d_show(imband, transform)))
# plt.show()

recon = ucurvinv(imband, transform)

err = img - recon
print(np.max(np.abs(err)))
plt.figure(figsize=(20, 60))
plt.imshow(np.real(np.concatenate((img, recon, err), axis=1)))

plt.figure()
plt.imshow(np.abs(np.fft.fftshift(np.fft.fftn(err))))
# plt.show()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try to always remove commented codes as these are unlikely useful for our users :)


################################################################################


sz = [256, 256]
cfg = [[3, 3], [6, 6]]
x = np.random.rand(256 * 256)
y = np.random.rand(262144)
F = pylops.signalprocessing.UDCT(sz, cfg)
print(np.dot(y, F * x))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

printing is probably not that useful... but anyways this will likely go away and you will use pylops @/* and .H @ / .H * to perform the forward and adjoint of the curvelet transform instead of the methods of ucurv above

print(np.dot(x, F.T * y))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prefer to use .H @ (or .H *) instead of .T *

2 changes: 2 additions & 0 deletions pylops/signalprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Patch2D 2D Patching transform operator.
Patch3D 3D Patching transform operator.
Fredholm1 Fredholm integral of first kind.
UDCT Uniform Discrete Curvelet Transform

"""

Expand Down Expand Up @@ -66,6 +67,7 @@
from .seislet import *
from .dct import *
from .dtcwt import *
from .udct import *


__all__ = [
Expand Down
63 changes: 63 additions & 0 deletions pylops/signalprocessing/udct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
__all__ = ["UDCT"]

import numpy as np

from pylops import LinearOperator
from pylops.utils import deps
from pylops.utils.decorators import reshaped
from pylops.utils.typing import NDArray
from ucurv import udct, ucurvfwd, ucurvinv, bands2vec, vec2bands

ucurv_message = deps.ucurv_import("the ucurv module")


class UDCT(LinearOperator):
r"""Uniform Discrete Curvelet Transform

Perform the multidimensional discrete curvelet transforms

The UDCT operator is a wraparound of the ucurvfwd and ucurvinv
calls in the UCURV package. Refer to
https://ucurv.readthedocs.io for a detailed description of the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this already available, so far I get 404?

input parameters.

Parameters
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have only 3 parameters here, but they should be those of the init method, which are 5 in your case

----------
udct : :obj:`DTypeLike`, optional
Type of elements in input array.
dtype : :obj:`DTypeLike`, optional
Type of elements in input array.
name : :obj:`str`, optional
Name of operator (to be used by :func:`pylops.utils.describe.describe`)

Notes
-----
The UDCT operator applies the uniform discrete curvelet transform
in forward and adjoint modes from the ``ucurv`` library.

The ``ucurv`` library uses a udct object to represent all the parameters
of the multidimensional transform. The udct object have to be created with the size
of the data need to be transformed, and the cfg parameter which control the
number of resolution and direction.
"""
def __init__(self, sz, cfg, complex=False, sparse=False, dtype=None):
self.udct = udct(sz, cfg, complex, sparse)
self.shape = (self.udct.len, np.prod(sz))
self.dtype = np.dtype(dtype)
self.explicit = False
self.rmatvec_count = 0
self.matvec_count = 0

@reshaped
def _matvec(self, x: NDArray) -> NDArray:
img = x.reshape(self.udct.sz)
band = ucurvfwd(img, self.udct)
bvec = bands2vec(band)
return bvec

@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
band = vec2bands(x, self.udct)
recon = ucurvinv(band, self.udct)
recon2 = recon.reshape(self.udct.sz)
return recon2
20 changes: 19 additions & 1 deletion pylops/utils/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"jax_enabled",
"devito_enabled",
"dtcwt_enabled",
"ucurv_enabled",
"numba_enabled",
"pyfftw_enabled",
"pywt_enabled",
Expand Down Expand Up @@ -76,7 +77,6 @@ def jax_import(message: Optional[str] = None) -> str:
"'pip install jax'; "
"for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'"
)

return jax_message


Expand Down Expand Up @@ -114,6 +114,23 @@ def dtcwt_import(message: Optional[str] = None) -> str:
return dtcwt_message


def ucurv_import(message: Optional[str] = None) -> str:
if ucurv_enabled:
try:
import ucurv # noqa: F401

ucurv_message = None
except Exception as e:
ucurv_message = f"Failed to import ucurv (error:{e})."
else:
ucurv_message = (
f"UCURV not available. "
f"In order to be able to use "
f'{message} run "pip install ucurv".'
)
return ucurv_message


def numba_import(message: Optional[str] = None) -> str:
if numba_enabled:
try:
Expand Down Expand Up @@ -238,6 +255,7 @@ def sympy_import(message: Optional[str] = None) -> str:
)
devito_enabled = util.find_spec("devito") is not None
dtcwt_enabled = util.find_spec("dtcwt") is not None
ucurv_enabled = util.find_spec("ucurv") is not None
numba_enabled = util.find_spec("numba") is not None
pyfftw_enabled = util.find_spec("pyfftw") is not None
pywt_enabled = util.find_spec("pywt") is not None
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ advanced = [
"scikit-fmm",
"spgl1",
"dtcwt",
"ucurv",
]

[tool.setuptools.packages.find]
Expand Down
63 changes: 63 additions & 0 deletions pytests/test_udct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import numpy as np
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Write all tests using only the pylops method and not the methods of ucurv...

Moreover, the bare minimum tests should have the dottest; in your case since you know the adjoint=inv, you can also add the assert that you have already.

import pytest

# from pylops.signalprocessing import UDCT

from ucurv import udct, ucurvfwd, ucurvinv, bands2vec, vec2bands

eps = 1e-6
shapes = [
[[256, 256], ],
[[32, 32, 32], ],
[[16, 16, 16, 16], ]
]

configurations = [
[[[3, 3]],
[[6, 6]],
[[12, 12]],
[[12, 12], [24, 24]],
[[12, 12], [3, 3], [6, 6]],
[[12, 12], [3, 3], [6, 6], [24, 24]]],
[[[3, 3, 3]],
[[6, 6, 6]],
[[12, 12, 12]],
[[12, 12, 12], [24, 24, 24]]],
# [[12, 12, 12], [3, 3, 3], [6, 6, 6]],
# [[12, 12, 12], [3, 3, 3], [6, 6, 6], [12, 24, 24]],

[[[3, 3, 3, 3]]],
# [[6, 6, 6, 6]],
# [[12, 12, 12, 12]],
# [[12, 12, 12, 12], [24, 24, 24, 24]],
# [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6]],
# [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6], [12, 24, 24, 24]],
]

combinations = [
(shape, config)
for shape_list, config_list in zip(shapes, configurations)
for shape in shape_list
for config in config_list]


@pytest.mark.parametrize("shape, cfg", combinations)
def test_ucurv(shape, cfg):
data = np.random.rand(*shape)
tf = udct(shape, cfg)
band = ucurvfwd(data, tf)
recon = ucurvinv(band, tf)
are_close = np.all(np.isclose(data, recon, atol=eps))
assert are_close


@pytest.mark.parametrize("shape, cfg", combinations)
def test_vectorize(shape, cfg):
data = np.random.rand(*shape)
tf = udct(shape, cfg)
band = ucurvfwd(data, tf)
flat = bands2vec(band)
unflat = vec2bands(flat, tf)
recon = ucurvinv(unflat, tf)
are_close = np.all(np.isclose(data, recon, atol=eps))
assert are_close
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ scikit-fmm
sympy
devito
dtcwt
ucurv
matplotlib
ipython
pytest
Expand Down
1 change: 1 addition & 0 deletions requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ scikit-fmm
sympy
devito
dtcwt
ucurv
matplotlib
ipython
pytest
Expand Down