diff --git a/examples/plot_udct.py b/examples/plot_udct.py new file mode 100644 index 00000000..35aac56e --- /dev/null +++ b/examples/plot_udct.py @@ -0,0 +1,49 @@ +""" +Uniform Discrete Curvelet Transform +=================================== +This example shows how to use the :py:class:`pylops.signalprocessing.UDCT` operator to perform the +Uniform Discrete Curvelet Transform on a (multi-dimensional) input array. +""" + + +from ucurv import * +import matplotlib.pyplot as plt +import pylops +plt.close("all") + +if False: + sz = [512, 512] + cfg = [[3, 3], [6,6]] + res = len(cfg) + rsq = zoneplate(sz) + img = rsq - np.mean(rsq) + + transform = udct(sz, cfg, complex = False, high = "curvelet") + + imband = ucurvfwd(img, transform) + plt.figure(figsize = (20, 60)) + print(imband.keys()) + plt.imshow(np.abs(ucurv2d_show(imband, transform))) + # plt.show() + + recon = ucurvinv(imband, transform) + + err = img - recon + print(np.max(np.abs(err))) + plt.figure(figsize = (20, 60)) + plt.imshow(np.real(np.concatenate((img, recon, err), axis = 1))) + + plt.figure() + plt.imshow(np.abs(np.fft.fftshift(np.fft.fftn(err)))) + # plt.show() + +################################################################################ + + +sz = [256, 256] +cfg = [[3,3],[6,6]] +x = np.random.rand(256*256) +y = np.random.rand(262144) +F = pylops.signalprocessing.UDCT(sz,cfg) +print(np.dot(y,F*x)) +print(np.dot(x,F.T*y)) \ No newline at end of file diff --git a/pylops/signalprocessing/udct.py b/pylops/signalprocessing/udct.py new file mode 100644 index 00000000..8d2ef5b1 --- /dev/null +++ b/pylops/signalprocessing/udct.py @@ -0,0 +1,34 @@ +__all__ = ["UDCT"] + +from typing import Any, NewType, Union + +import numpy as np + +from pylops import LinearOperator +from pylops.utils import deps +from pylops.utils._internal import _value_or_sized_to_tuple +from pylops.utils.decorators import reshaped +from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray +from ucurv import * + +ucurv_message = deps.ucurv_import("the ucurv module") + +class UDCT(LinearOperator): + def __init__(self, sz, cfg, complex = False, sparse = False, dtype=None): + self.udct = udct(sz, cfg, complex, sparse) + self.shape = (self.udct.len, np.prod(sz)) + self.dtype = np.dtype(dtype) + self.explicit = False + self.rmatvec_count = 0 + self.matvec_count = 0 + def _matvec(self, x): + img = x.reshape(self.udct.sz) + band = ucurvfwd(img, self.udct) + bvec = bands2vec(band) + return bvec + + def _rmatvec(self, x): + band = vec2bands(x, self.udct) + recon = ucurvinv(band, self.udct) + recon2 = recon.reshape(self.udct.sz) + return recon2 \ No newline at end of file diff --git a/pytests/test_udct.py b/pytests/test_udct.py new file mode 100644 index 00000000..c3393fe6 --- /dev/null +++ b/pytests/test_udct.py @@ -0,0 +1,67 @@ +import numpy as np +import pytest + +from pylops.signalprocessing import UDCT + +from ucurv import * + +import numpy as np + +eps = 1e-6 +shapes = [ + [[256, 256], ], + [[32, 32, 32], ], + [[16, 16, 16, 16], ] +] + +configurations = [ + [[[3, 3]], + [[6, 6]], + [[12, 12]], + [[12, 12], [24, 24]], + [[12, 12], [3, 3], [6, 6]], + [[12, 12], [3, 3], [6, 6], [24, 24]], + ], + [[[3, 3, 3]], + [[6, 6, 6]], + [[12, 12, 12]], + [[12, 12, 12], [24, 24, 24]], + # [[12, 12, 12], [3, 3, 3], [6, 6, 6]], + # [[12, 12, 12], [3, 3, 3], [6, 6, 6], [12, 24, 24]], + ], + [[[3, 3, 3, 3]], + # [[6, 6, 6, 6]], + # [[12, 12, 12, 12]], + # [[12, 12, 12, 12], [24, 24, 24, 24]], + # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6]], + # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6], [12, 24, 24, 24]], + ], +] + +combinations = [ + (shape, config) + for shape_list, config_list in zip(shapes, configurations) + for shape in shape_list + for config in config_list +] + +@pytest.mark.parametrize("shape, cfg", combinations) +def test_ucurv(shape, cfg): + data = np.random.rand(*shape) + tf = udct(shape, cfg) + band = ucurvfwd(data, tf) + recon = ucurvinv(band, tf) + are_close = np.all(np.isclose(data, recon, atol=eps)) + assert(are_close == True) + +@pytest.mark.parametrize("shape, cfg", combinations) +def test_vectorize(shape, cfg): + data = np.random.rand(*shape) + tf = udct(shape, cfg) + band = ucurvfwd(data, tf) + flat = bands2vec(band) + unflat = vec2bands(flat, tf) + recon = ucurvinv(band, tf) + are_close = np.all(np.isclose(data, recon, atol=eps)) + assert(are_close == True) +