-
Notifications
You must be signed in to change notification settings - Fork 105
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
150 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
""" | ||
Uniform Discrete Curvelet Transform | ||
=================================== | ||
This example shows how to use the :py:class:`pylops.signalprocessing.UDCT` operator to perform the | ||
Uniform Discrete Curvelet Transform on a (multi-dimensional) input array. | ||
""" | ||
|
||
|
||
from ucurv import * | ||
import matplotlib.pyplot as plt | ||
import pylops | ||
plt.close("all") | ||
|
||
if False: | ||
sz = [512, 512] | ||
cfg = [[3, 3], [6,6]] | ||
res = len(cfg) | ||
rsq = zoneplate(sz) | ||
img = rsq - np.mean(rsq) | ||
|
||
transform = udct(sz, cfg, complex = False, high = "curvelet") | ||
|
||
imband = ucurvfwd(img, transform) | ||
plt.figure(figsize = (20, 60)) | ||
print(imband.keys()) | ||
plt.imshow(np.abs(ucurv2d_show(imband, transform))) | ||
# plt.show() | ||
|
||
recon = ucurvinv(imband, transform) | ||
|
||
err = img - recon | ||
print(np.max(np.abs(err))) | ||
plt.figure(figsize = (20, 60)) | ||
plt.imshow(np.real(np.concatenate((img, recon, err), axis = 1))) | ||
|
||
plt.figure() | ||
plt.imshow(np.abs(np.fft.fftshift(np.fft.fftn(err)))) | ||
# plt.show() | ||
|
||
################################################################################ | ||
|
||
|
||
sz = [256, 256] | ||
cfg = [[3,3],[6,6]] | ||
x = np.random.rand(256*256) | ||
y = np.random.rand(262144) | ||
F = pylops.signalprocessing.UDCT(sz,cfg) | ||
print(np.dot(y,F*x)) | ||
print(np.dot(x,F.T*y)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
__all__ = ["UDCT"] | ||
|
||
from typing import Any, NewType, Union | ||
|
||
import numpy as np | ||
|
||
from pylops import LinearOperator | ||
from pylops.utils import deps | ||
from pylops.utils._internal import _value_or_sized_to_tuple | ||
from pylops.utils.decorators import reshaped | ||
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray | ||
from ucurv import * | ||
|
||
ucurv_message = deps.ucurv_import("the ucurv module") | ||
|
||
class UDCT(LinearOperator): | ||
def __init__(self, sz, cfg, complex = False, sparse = False, dtype=None): | ||
self.udct = udct(sz, cfg, complex, sparse) | ||
self.shape = (self.udct.len, np.prod(sz)) | ||
self.dtype = np.dtype(dtype) | ||
self.explicit = False | ||
self.rmatvec_count = 0 | ||
self.matvec_count = 0 | ||
def _matvec(self, x): | ||
img = x.reshape(self.udct.sz) | ||
band = ucurvfwd(img, self.udct) | ||
bvec = bands2vec(band) | ||
return bvec | ||
|
||
def _rmatvec(self, x): | ||
band = vec2bands(x, self.udct) | ||
recon = ucurvinv(band, self.udct) | ||
recon2 = recon.reshape(self.udct.sz) | ||
return recon2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from pylops.signalprocessing import UDCT | ||
|
||
from ucurv import * | ||
|
||
import numpy as np | ||
|
||
eps = 1e-6 | ||
shapes = [ | ||
[[256, 256], ], | ||
[[32, 32, 32], ], | ||
[[16, 16, 16, 16], ] | ||
] | ||
|
||
configurations = [ | ||
[[[3, 3]], | ||
[[6, 6]], | ||
[[12, 12]], | ||
[[12, 12], [24, 24]], | ||
[[12, 12], [3, 3], [6, 6]], | ||
[[12, 12], [3, 3], [6, 6], [24, 24]], | ||
], | ||
[[[3, 3, 3]], | ||
[[6, 6, 6]], | ||
[[12, 12, 12]], | ||
[[12, 12, 12], [24, 24, 24]], | ||
# [[12, 12, 12], [3, 3, 3], [6, 6, 6]], | ||
# [[12, 12, 12], [3, 3, 3], [6, 6, 6], [12, 24, 24]], | ||
], | ||
[[[3, 3, 3, 3]], | ||
# [[6, 6, 6, 6]], | ||
# [[12, 12, 12, 12]], | ||
# [[12, 12, 12, 12], [24, 24, 24, 24]], | ||
# [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6]], | ||
# [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6], [12, 24, 24, 24]], | ||
], | ||
] | ||
|
||
combinations = [ | ||
(shape, config) | ||
for shape_list, config_list in zip(shapes, configurations) | ||
for shape in shape_list | ||
for config in config_list | ||
] | ||
|
||
@pytest.mark.parametrize("shape, cfg", combinations) | ||
def test_ucurv(shape, cfg): | ||
data = np.random.rand(*shape) | ||
tf = udct(shape, cfg) | ||
band = ucurvfwd(data, tf) | ||
recon = ucurvinv(band, tf) | ||
are_close = np.all(np.isclose(data, recon, atol=eps)) | ||
assert(are_close == True) | ||
|
||
@pytest.mark.parametrize("shape, cfg", combinations) | ||
def test_vectorize(shape, cfg): | ||
data = np.random.rand(*shape) | ||
tf = udct(shape, cfg) | ||
band = ucurvfwd(data, tf) | ||
flat = bands2vec(band) | ||
unflat = vec2bands(flat, tf) | ||
recon = ucurvinv(band, tf) | ||
are_close = np.all(np.isclose(data, recon, atol=eps)) | ||
assert(are_close == True) | ||
|