-
Notifications
You must be signed in to change notification settings - Fork 107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
UDCT operator #610
base: dev
Are you sure you want to change the base?
UDCT operator #610
Changes from all commits
7c58f0f
b99e8f6
3666662
934a154
9d6748b
b367ea9
c45c7bf
8bb63c0
f3a7927
79b9128
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ dependencies: | |
- pip: | ||
- devito | ||
- dtcwt | ||
- ucurv | ||
- scikit-fmm | ||
- spgl1 | ||
- pytest-runner | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
""" | ||
Uniform Discrete Curvelet Transform | ||
=================================== | ||
This example shows how to use the :py:class:`pylops.signalprocessing.UDCT` operator to perform the | ||
Uniform Discrete Curvelet Transform on a (multi-dimensional) input array. | ||
""" | ||
|
||
import numpy as np | ||
from ucurv import udct, zoneplate, ucurvfwd, ucurvinv, ucurv2d_show | ||
import matplotlib.pyplot as plt | ||
import pylops | ||
plt.close("all") | ||
|
||
sz = [512, 512] | ||
cfg = [[3, 3], [6, 6]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Try to add some additional text explaining what you are doing, right now it is a bit hard to follow the example |
||
res = len(cfg) | ||
rsq = zoneplate(sz) | ||
img = rsq - np.mean(rsq) | ||
|
||
transform = udct(sz, cfg, complex=False, high="curvelet") | ||
|
||
imband = ucurvfwd(img, transform) | ||
plt.figure(figsize=(20, 60)) | ||
print(imband.keys()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this print meant to be there, maybe not whilst you make the plot? |
||
plt.imshow(np.abs(ucurv2d_show(imband, transform))) | ||
# plt.show() | ||
|
||
recon = ucurvinv(imband, transform) | ||
|
||
err = img - recon | ||
print(np.max(np.abs(err))) | ||
plt.figure(figsize=(20, 60)) | ||
plt.imshow(np.real(np.concatenate((img, recon, err), axis=1))) | ||
|
||
plt.figure() | ||
plt.imshow(np.abs(np.fft.fftshift(np.fft.fftn(err)))) | ||
# plt.show() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. try to always remove commented codes as these are unlikely useful for our users :) |
||
|
||
################################################################################ | ||
|
||
|
||
sz = [256, 256] | ||
cfg = [[3, 3], [6, 6]] | ||
x = np.random.rand(256 * 256) | ||
y = np.random.rand(262144) | ||
F = pylops.signalprocessing.UDCT(sz, cfg) | ||
print(np.dot(y, F * x)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. printing is probably not that useful... but anyways this will likely go away and you will use pylops @/* and .H @ / .H * to perform the forward and adjoint of the curvelet transform instead of the methods of ucurv above |
||
print(np.dot(x, F.T * y)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We prefer to use .H @ (or .H *) instead of .T * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
__all__ = ["UDCT"] | ||
|
||
import numpy as np | ||
|
||
from pylops import LinearOperator | ||
from pylops.utils import deps | ||
from pylops.utils.decorators import reshaped | ||
from pylops.utils.typing import NDArray | ||
from ucurv import udct, ucurvfwd, ucurvinv, bands2vec, vec2bands | ||
|
||
ucurv_message = deps.ucurv_import("the ucurv module") | ||
|
||
|
||
class UDCT(LinearOperator): | ||
r"""Uniform Discrete Curvelet Transform | ||
|
||
Perform the multidimensional discrete curvelet transforms | ||
|
||
The UDCT operator is a wraparound of the ucurvfwd and ucurvinv | ||
calls in the UCURV package. Refer to | ||
https://ucurv.readthedocs.io for a detailed description of the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this already available, so far I get 404? |
||
input parameters. | ||
|
||
Parameters | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You have only 3 parameters here, but they should be those of the init method, which are 5 in your case |
||
---------- | ||
udct : :obj:`DTypeLike`, optional | ||
Type of elements in input array. | ||
dtype : :obj:`DTypeLike`, optional | ||
Type of elements in input array. | ||
name : :obj:`str`, optional | ||
Name of operator (to be used by :func:`pylops.utils.describe.describe`) | ||
|
||
Notes | ||
----- | ||
The UDCT operator applies the uniform discrete curvelet transform | ||
in forward and adjoint modes from the ``ucurv`` library. | ||
|
||
The ``ucurv`` library uses a udct object to represent all the parameters | ||
of the multidimensional transform. The udct object have to be created with the size | ||
of the data need to be transformed, and the cfg parameter which control the | ||
number of resolution and direction. | ||
""" | ||
def __init__(self, sz, cfg, complex=False, sparse=False, dtype=None): | ||
self.udct = udct(sz, cfg, complex, sparse) | ||
self.shape = (self.udct.len, np.prod(sz)) | ||
self.dtype = np.dtype(dtype) | ||
self.explicit = False | ||
self.rmatvec_count = 0 | ||
self.matvec_count = 0 | ||
|
||
@reshaped | ||
def _matvec(self, x: NDArray) -> NDArray: | ||
img = x.reshape(self.udct.sz) | ||
band = ucurvfwd(img, self.udct) | ||
bvec = bands2vec(band) | ||
return bvec | ||
|
||
@reshaped | ||
def _rmatvec(self, x: NDArray) -> NDArray: | ||
band = vec2bands(x, self.udct) | ||
recon = ucurvinv(band, self.udct) | ||
recon2 = recon.reshape(self.udct.sz) | ||
return recon2 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,6 +44,7 @@ advanced = [ | |
"scikit-fmm", | ||
"spgl1", | ||
"dtcwt", | ||
"ucurv", | ||
] | ||
|
||
[tool.setuptools.packages.find] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import numpy as np | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Write all tests using only the pylops method and not the methods of Moreover, the bare minimum tests should have the dottest; in your case since you know the adjoint=inv, you can also add the assert that you have already. |
||
import pytest | ||
|
||
# from pylops.signalprocessing import UDCT | ||
|
||
from ucurv import udct, ucurvfwd, ucurvinv, bands2vec, vec2bands | ||
|
||
eps = 1e-6 | ||
shapes = [ | ||
[[256, 256], ], | ||
[[32, 32, 32], ], | ||
[[16, 16, 16, 16], ] | ||
] | ||
|
||
configurations = [ | ||
[[[3, 3]], | ||
[[6, 6]], | ||
[[12, 12]], | ||
[[12, 12], [24, 24]], | ||
[[12, 12], [3, 3], [6, 6]], | ||
[[12, 12], [3, 3], [6, 6], [24, 24]]], | ||
[[[3, 3, 3]], | ||
[[6, 6, 6]], | ||
[[12, 12, 12]], | ||
[[12, 12, 12], [24, 24, 24]]], | ||
# [[12, 12, 12], [3, 3, 3], [6, 6, 6]], | ||
# [[12, 12, 12], [3, 3, 3], [6, 6, 6], [12, 24, 24]], | ||
|
||
[[[3, 3, 3, 3]]], | ||
# [[6, 6, 6, 6]], | ||
# [[12, 12, 12, 12]], | ||
# [[12, 12, 12, 12], [24, 24, 24, 24]], | ||
# [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6]], | ||
# [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6], [12, 24, 24, 24]], | ||
] | ||
|
||
combinations = [ | ||
(shape, config) | ||
for shape_list, config_list in zip(shapes, configurations) | ||
for shape in shape_list | ||
for config in config_list] | ||
|
||
|
||
@pytest.mark.parametrize("shape, cfg", combinations) | ||
def test_ucurv(shape, cfg): | ||
data = np.random.rand(*shape) | ||
tf = udct(shape, cfg) | ||
band = ucurvfwd(data, tf) | ||
recon = ucurvinv(band, tf) | ||
are_close = np.all(np.isclose(data, recon, atol=eps)) | ||
assert are_close | ||
|
||
|
||
@pytest.mark.parametrize("shape, cfg", combinations) | ||
def test_vectorize(shape, cfg): | ||
data = np.random.rand(*shape) | ||
tf = udct(shape, cfg) | ||
band = ucurvfwd(data, tf) | ||
flat = bands2vec(band) | ||
unflat = vec2bands(flat, tf) | ||
recon = ucurvinv(unflat, tf) | ||
are_close = np.all(np.isclose(data, recon, atol=eps)) | ||
assert are_close |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ scikit-fmm | |
sympy | ||
devito | ||
dtcwt | ||
ucurv | ||
matplotlib | ||
ipython | ||
pytest | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ scikit-fmm | |
sympy | ||
devito | ||
dtcwt | ||
ucurv | ||
matplotlib | ||
ipython | ||
pytest | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should try to write the entire example using the pylops.signalprocessing.UDCT methods (matvec/rmatvec) instead of those in the ucurv library. Of course, you can for example use the
ucurv2d_show
method but when you compute the forward and backward of your transform you should do this using the pylops operator.