From f3a792764076366cde427636ca1bc3f9ce7054e2 Mon Sep 17 00:00:00 2001 From: Duy Nguyen Date: Tue, 3 Sep 2024 16:43:42 +0100 Subject: [PATCH] flake8 correction --- examples/plot_udct.py | 53 ++++++++++++++++----------------- pylops/signalprocessing/udct.py | 51 ++++++++++++++++++++++++------- pylops/utils/deps.py | 9 ++++-- pytests/test_udct.py | 42 ++++++++++++-------------- 4 files changed, 91 insertions(+), 64 deletions(-) diff --git a/examples/plot_udct.py b/examples/plot_udct.py index 35aac56e..cf92035a 100644 --- a/examples/plot_udct.py +++ b/examples/plot_udct.py @@ -5,45 +5,44 @@ Uniform Discrete Curvelet Transform on a (multi-dimensional) input array. """ - -from ucurv import * +import numpy as np +from ucurv import udct, zoneplate, ucurvfwd, ucurvinv, ucurv2d_show import matplotlib.pyplot as plt import pylops plt.close("all") -if False: - sz = [512, 512] - cfg = [[3, 3], [6,6]] - res = len(cfg) - rsq = zoneplate(sz) - img = rsq - np.mean(rsq) +sz = [512, 512] +cfg = [[3, 3], [6, 6]] +res = len(cfg) +rsq = zoneplate(sz) +img = rsq - np.mean(rsq) - transform = udct(sz, cfg, complex = False, high = "curvelet") +transform = udct(sz, cfg, complex=False, high="curvelet") - imband = ucurvfwd(img, transform) - plt.figure(figsize = (20, 60)) - print(imband.keys()) - plt.imshow(np.abs(ucurv2d_show(imband, transform))) - # plt.show() +imband = ucurvfwd(img, transform) +plt.figure(figsize=(20, 60)) +print(imband.keys()) +plt.imshow(np.abs(ucurv2d_show(imband, transform))) +# plt.show() - recon = ucurvinv(imband, transform) +recon = ucurvinv(imband, transform) - err = img - recon - print(np.max(np.abs(err))) - plt.figure(figsize = (20, 60)) - plt.imshow(np.real(np.concatenate((img, recon, err), axis = 1))) +err = img - recon +print(np.max(np.abs(err))) +plt.figure(figsize=(20, 60)) +plt.imshow(np.real(np.concatenate((img, recon, err), axis=1))) - plt.figure() - plt.imshow(np.abs(np.fft.fftshift(np.fft.fftn(err)))) - # plt.show() +plt.figure() +plt.imshow(np.abs(np.fft.fftshift(np.fft.fftn(err)))) +# plt.show() ################################################################################ sz = [256, 256] -cfg = [[3,3],[6,6]] -x = np.random.rand(256*256) +cfg = [[3, 3], [6, 6]] +x = np.random.rand(256 * 256) y = np.random.rand(262144) -F = pylops.signalprocessing.UDCT(sz,cfg) -print(np.dot(y,F*x)) -print(np.dot(x,F.T*y)) \ No newline at end of file +F = pylops.signalprocessing.UDCT(sz, cfg) +print(np.dot(y, F * x)) +print(np.dot(x, F.T * y)) diff --git a/pylops/signalprocessing/udct.py b/pylops/signalprocessing/udct.py index 8d2ef5b1..ccafb5db 100644 --- a/pylops/signalprocessing/udct.py +++ b/pylops/signalprocessing/udct.py @@ -1,34 +1,63 @@ __all__ = ["UDCT"] -from typing import Any, NewType, Union - import numpy as np from pylops import LinearOperator from pylops.utils import deps -from pylops.utils._internal import _value_or_sized_to_tuple from pylops.utils.decorators import reshaped -from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray -from ucurv import * +from pylops.utils.typing import NDArray +from ucurv import udct, ucurvfwd, ucurvinv, bands2vec, vec2bands ucurv_message = deps.ucurv_import("the ucurv module") + class UDCT(LinearOperator): - def __init__(self, sz, cfg, complex = False, sparse = False, dtype=None): + r"""Uniform Discrete Curvelet Transform + + Perform the multidimensional discrete curvelet transforms + + The UDCT operator is a wraparound of the ucurvfwd and ucurvinv + calls in the UCURV package. Refer to + https://ucurv.readthedocs.io for a detailed description of the + input parameters. + + Parameters + ---------- + udct : :obj:`DTypeLike`, optional + Type of elements in input array. + dtype : :obj:`DTypeLike`, optional + Type of elements in input array. + name : :obj:`str`, optional + Name of operator (to be used by :func:`pylops.utils.describe.describe`) + + Notes + ----- + The UDCT operator applies the uniform discrete curvelet transform + in forward and adjoint modes from the ``ucurv`` library. + + The ``ucurv`` library uses a udct object to represent all the parameters + of the multidimensional transform. The udct object have to be created with the size + of the data need to be transformed, and the cfg parameter which control the + number of resolution and direction. + """ + def __init__(self, sz, cfg, complex=False, sparse=False, dtype=None): self.udct = udct(sz, cfg, complex, sparse) self.shape = (self.udct.len, np.prod(sz)) self.dtype = np.dtype(dtype) - self.explicit = False + self.explicit = False self.rmatvec_count = 0 self.matvec_count = 0 - def _matvec(self, x): + + @reshaped + def _matvec(self, x: NDArray) -> NDArray: img = x.reshape(self.udct.sz) band = ucurvfwd(img, self.udct) - bvec = bands2vec(band) + bvec = bands2vec(band) return bvec - def _rmatvec(self, x): + @reshaped + def _rmatvec(self, x: NDArray) -> NDArray: band = vec2bands(x, self.udct) recon = ucurvinv(band, self.udct) recon2 = recon.reshape(self.udct.sz) - return recon2 \ No newline at end of file + return recon2 diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py index b320028b..744b5aa3 100644 --- a/pylops/utils/deps.py +++ b/pylops/utils/deps.py @@ -1,9 +1,9 @@ __all__ = [ "cupy_enabled", - "jax_enabled", + "jax_enabled", "devito_enabled", "dtcwt_enabled", - "ucurv_enabled", + "ucurv_enabled", "numba_enabled", "pyfftw_enabled", "pywt_enabled", @@ -52,6 +52,7 @@ def cupy_import(message: Optional[str] = None) -> str: return cupy_message + def jax_import(message: Optional[str] = None) -> str: jax_test = ( util.find_spec("jax") is not None and int(os.getenv("JAX_PYLOPS", 1)) == 1 @@ -76,9 +77,9 @@ def jax_import(message: Optional[str] = None) -> str: "'pip install jax'; " "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'" ) - return jax_message + def devito_import(message: Optional[str] = None) -> str: if devito_enabled: try: @@ -112,6 +113,7 @@ def dtcwt_import(message: Optional[str] = None) -> str: ) return dtcwt_message + def ucurv_import(message: Optional[str] = None) -> str: if ucurv_enabled: try: @@ -128,6 +130,7 @@ def ucurv_import(message: Optional[str] = None) -> str: ) return ucurv_message + def numba_import(message: Optional[str] = None) -> str: if numba_enabled: try: diff --git a/pytests/test_udct.py b/pytests/test_udct.py index c3393fe6..4f194b20 100644 --- a/pytests/test_udct.py +++ b/pytests/test_udct.py @@ -1,11 +1,9 @@ import numpy as np import pytest -from pylops.signalprocessing import UDCT +# from pylops.signalprocessing import UDCT -from ucurv import * - -import numpy as np +from ucurv import udct, ucurvfwd, ucurvinv, bands2vec, vec2bands eps = 1e-6 shapes = [ @@ -15,35 +13,33 @@ ] configurations = [ - [[[3, 3]], + [[[3, 3]], [[6, 6]], [[12, 12]], [[12, 12], [24, 24]], [[12, 12], [3, 3], [6, 6]], - [[12, 12], [3, 3], [6, 6], [24, 24]], - ], - [[[3, 3, 3]], + [[12, 12], [3, 3], [6, 6], [24, 24]]], + [[[3, 3, 3]], [[6, 6, 6]], [[12, 12, 12]], - [[12, 12, 12], [24, 24, 24]], + [[12, 12, 12], [24, 24, 24]]], # [[12, 12, 12], [3, 3, 3], [6, 6, 6]], - # [[12, 12, 12], [3, 3, 3], [6, 6, 6], [12, 24, 24]], - ], - [[[3, 3, 3, 3]], + # [[12, 12, 12], [3, 3, 3], [6, 6, 6], [12, 24, 24]], + + [[[3, 3, 3, 3]]], # [[6, 6, 6, 6]], # [[12, 12, 12, 12]], # [[12, 12, 12, 12], [24, 24, 24, 24]], # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6]], - # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6], [12, 24, 24, 24]], - ], + # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6], [12, 24, 24, 24]], ] combinations = [ - (shape, config) - for shape_list, config_list in zip(shapes, configurations) - for shape in shape_list - for config in config_list -] + (shape, config) + for shape_list, config_list in zip(shapes, configurations) + for shape in shape_list + for config in config_list] + @pytest.mark.parametrize("shape, cfg", combinations) def test_ucurv(shape, cfg): @@ -52,7 +48,8 @@ def test_ucurv(shape, cfg): band = ucurvfwd(data, tf) recon = ucurvinv(band, tf) are_close = np.all(np.isclose(data, recon, atol=eps)) - assert(are_close == True) + assert are_close + @pytest.mark.parametrize("shape, cfg", combinations) def test_vectorize(shape, cfg): @@ -61,7 +58,6 @@ def test_vectorize(shape, cfg): band = ucurvfwd(data, tf) flat = bands2vec(band) unflat = vec2bands(flat, tf) - recon = ucurvinv(band, tf) + recon = ucurvinv(unflat, tf) are_close = np.all(np.isclose(data, recon, atol=eps)) - assert(are_close == True) - + assert are_close