Skip to content

Commit

Permalink
commit main files
Browse files Browse the repository at this point in the history
  • Loading branch information
yud08 committed Sep 1, 2024
1 parent 9d6748b commit b367ea9
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 0 deletions.
49 changes: 49 additions & 0 deletions examples/plot_udct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
Uniform Discrete Curvelet Transform
===================================
This example shows how to use the :py:class:`pylops.signalprocessing.UDCT` operator to perform the
Uniform Discrete Curvelet Transform on a (multi-dimensional) input array.
"""


from ucurv import *
import matplotlib.pyplot as plt
import pylops
plt.close("all")

if False:
sz = [512, 512]
cfg = [[3, 3], [6,6]]
res = len(cfg)
rsq = zoneplate(sz)
img = rsq - np.mean(rsq)

transform = udct(sz, cfg, complex = False, high = "curvelet")

imband = ucurvfwd(img, transform)
plt.figure(figsize = (20, 60))
print(imband.keys())
plt.imshow(np.abs(ucurv2d_show(imband, transform)))
# plt.show()

recon = ucurvinv(imband, transform)

err = img - recon
print(np.max(np.abs(err)))
plt.figure(figsize = (20, 60))
plt.imshow(np.real(np.concatenate((img, recon, err), axis = 1)))

plt.figure()
plt.imshow(np.abs(np.fft.fftshift(np.fft.fftn(err))))
# plt.show()

################################################################################


sz = [256, 256]
cfg = [[3,3],[6,6]]
x = np.random.rand(256*256)
y = np.random.rand(262144)
F = pylops.signalprocessing.UDCT(sz,cfg)
print(np.dot(y,F*x))
print(np.dot(x,F.T*y))
34 changes: 34 additions & 0 deletions pylops/signalprocessing/udct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
__all__ = ["UDCT"]

from typing import Any, NewType, Union

import numpy as np

from pylops import LinearOperator
from pylops.utils import deps
from pylops.utils._internal import _value_or_sized_to_tuple
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
from ucurv import *

ucurv_message = deps.ucurv_import("the ucurv module")

class UDCT(LinearOperator):
def __init__(self, sz, cfg, complex = False, sparse = False, dtype=None):
self.udct = udct(sz, cfg, complex, sparse)
self.shape = (self.udct.len, np.prod(sz))
self.dtype = np.dtype(dtype)
self.explicit = False
self.rmatvec_count = 0
self.matvec_count = 0
def _matvec(self, x):
img = x.reshape(self.udct.sz)
band = ucurvfwd(img, self.udct)
bvec = bands2vec(band)
return bvec

def _rmatvec(self, x):
band = vec2bands(x, self.udct)
recon = ucurvinv(band, self.udct)
recon2 = recon.reshape(self.udct.sz)
return recon2
67 changes: 67 additions & 0 deletions pytests/test_udct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np
import pytest

from pylops.signalprocessing import UDCT

from ucurv import *

import numpy as np

eps = 1e-6
shapes = [
[[256, 256], ],
[[32, 32, 32], ],
[[16, 16, 16, 16], ]
]

configurations = [
[[[3, 3]],
[[6, 6]],
[[12, 12]],
[[12, 12], [24, 24]],
[[12, 12], [3, 3], [6, 6]],
[[12, 12], [3, 3], [6, 6], [24, 24]],
],
[[[3, 3, 3]],
[[6, 6, 6]],
[[12, 12, 12]],
[[12, 12, 12], [24, 24, 24]],
# [[12, 12, 12], [3, 3, 3], [6, 6, 6]],
# [[12, 12, 12], [3, 3, 3], [6, 6, 6], [12, 24, 24]],
],
[[[3, 3, 3, 3]],
# [[6, 6, 6, 6]],
# [[12, 12, 12, 12]],
# [[12, 12, 12, 12], [24, 24, 24, 24]],
# [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6]],
# [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6], [12, 24, 24, 24]],
],
]

combinations = [
(shape, config)
for shape_list, config_list in zip(shapes, configurations)
for shape in shape_list
for config in config_list
]

@pytest.mark.parametrize("shape, cfg", combinations)
def test_ucurv(shape, cfg):
data = np.random.rand(*shape)
tf = udct(shape, cfg)
band = ucurvfwd(data, tf)
recon = ucurvinv(band, tf)
are_close = np.all(np.isclose(data, recon, atol=eps))
assert(are_close == True)

@pytest.mark.parametrize("shape, cfg", combinations)
def test_vectorize(shape, cfg):
data = np.random.rand(*shape)
tf = udct(shape, cfg)
band = ucurvfwd(data, tf)
flat = bands2vec(band)
unflat = vec2bands(flat, tf)
recon = ucurvinv(band, tf)
are_close = np.all(np.isclose(data, recon, atol=eps))
assert(are_close == True)

0 comments on commit b367ea9

Please sign in to comment.