Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SMALL] Add support for Undecimated Wavelet Transform, with right adjoint operator #38

Merged
merged 9 commits into from
Oct 3, 2019
23 changes: 14 additions & 9 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,30 @@ install:
- git clone https://github.com/CEA-COSMIC/pysap-data.git $HOME/.local/share/pysap/pysap-data
- ln -s $HOME/.local/share/pysap/pysap-data/pysap-data/* $HOME/.local/share/pysap
- ls -l $HOME/.local/share/pysap
- pip install numpy
- pip install --upgrade pip
- pip install matplotlib
- pip install cython
- pip install coverage nose pytest pytest-cov
- pip install coveralls
- pip install pycodestyle
- pip install git+https://github.com/CEA-COSMIC/pysap@master
- pip install git+https://github.com/ghisvail/pyNFFT.git
- pip install -b $TRAVIS_BUILD_DIR/build -t $TRAVIS_BUILD_DIR/install --no-clean --upgrade .
- ls $TRAVIS_BUILD_DIR/install
- export PYTHONPATH=$TRAVIS_BUILD_DIR/install:$PYTHONPATH
- pip install git+https://github.com/chaithyagr/ModOpt.git
- pushd ../
- git clone https://github.com/CEA-COSMIC/pysap
- cd pysap
- python setup.py install
- if [ $TRAVIS_PYTHON_VERSION == "3.5" ]; then
export PATH=$PATH:$TRAVIS_BUILD_DIR/build/temp.linux-x86_64-3.5/extern/bin;
export PATH=$PATH:$PWD/build/temp.linux-x86_64-3.5/extern/bin;
fi
- if [ $TRAVIS_PYTHON_VERSION == "3.6" ]; then
export PATH=$PATH:$TRAVIS_BUILD_DIR/build/temp.linux-x86_64-3.6/extern/bin;
export PATH=$PATH:$PWD/build/temp.linux-x86_64-3.6/extern/bin;
fi
- if [ $TRAVIS_PYTHON_VERSION == "3.7" ]; then
export PATH=$PATH:$TRAVIS_BUILD_DIR/build/temp.linux-x86_64-3.7/extern/bin;
export PATH=$PATH:$PWD/build/temp.linux-x86_64-3.7/extern/bin;
fi
- popd
- export PYTHONPATH=$TRAVIS_BUILD_DIR/install:$PYTHONPATH
- pip install git+https://github.com/ghisvail/pyNFFT.git
- pip install -b $TRAVIS_BUILD_DIR/build -t $TRAVIS_BUILD_DIR/install --no-clean --upgrade .

script:
- python setup.py test
Expand Down
9 changes: 9 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ This work is made available by a community of people, amoung which the
CEA Neurospin UNATI and CEA CosmoStat laboratories, in particular A. Grigis,
J.-L. Starck, P. Ciuciu, and S. Farrens.

Installation instructions
===============

Install python-pySAP using `pip install python-pySAP`. Later install pysap-mri by calling setup.py
Note: If you want to use undecimated wavelet transform, please point the `$PATH` environment variable to
pysap external binary directory:

`export PATH=$PATH:/path-to-pysap/build/temp.linux-x86_64-<PYTHON_VERSION>/extern/bin/`

Important links
===============

Expand Down
1 change: 1 addition & 0 deletions mri/numerics/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@

# Package import
from mri.reconstruct.linear import WaveletN
from mri.reconstruct.linear import WaveletUD2
195 changes: 190 additions & 5 deletions mri/reconstruct/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@


# Package import
from modopt.signal.wavelet import get_mr_filters, filter_convolve
import pysap
from pysap.base.utils import flatten
from pysap.base.utils import unflatten

# Third party import
import numpy
from joblib import Parallel, delayed
chaithyagr marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np


class WaveletN(object):
Expand Down Expand Up @@ -69,7 +71,7 @@ def op(self, data):
coeffs: ndarray
the wavelet coefficients.
"""
if isinstance(data, numpy.ndarray):
if isinstance(data, np.ndarray):
data = pysap.Image(data=data)
self.transform.data = data
self.transform.analysis()
Expand Down Expand Up @@ -114,13 +116,196 @@ def l2norm(self, shape):
the L2 norm.
"""
# Create fake data
shape = numpy.asarray(shape)
shape = np.asarray(shape)
shape += shape % 2
fake_data = numpy.zeros(shape)
fake_data = np.zeros(shape)
fake_data[tuple(zip(shape // 2))] = 1

# Call mr_transform
data = self.op(fake_data)

# Compute the L2 norm
return numpy.linalg.norm(data)
return np.linalg.norm(data)


class WaveletUD2(object):
"""The wavelet undecimated operator using pysap wrapper.
"""
def __init__(self, wavelet_id=24, nb_scale=4, multichannel=False,
n_cpu=1, backend='threading', verbose=0):
"""Init function for Undecimated wavelet transform

Parameters
-----------
wavelet_id: int, default 24 = undecimated (bi-) orthogonal transform
ID of wavelet being used
nb_scale: int, default 4
the number of scales in the decomposition.
chaithyagr marked this conversation as resolved.
Show resolved Hide resolved
multichannel: bool, default False
Boolean value to indicate if the incoming data is from
multiple-channels
n_cpu: int, default 0
Number of CPUs to run on. Only applicable if multichannel=True.
backend: 'threading' | 'multiprocessing', default 'threading'
Denotes the backend to use for parallel execution across
multiple channels.
verbose: int, default 0
The verbosity level for Parallel operation from joblib
Private Variables:
chaithyagr marked this conversation as resolved.
Show resolved Hide resolved
_has_run: Checks if the get_mr_filters was called already
"""
self.wavelet_id = wavelet_id
self.multichannel = multichannel
self.nb_scale = nb_scale
self.n_cpu = n_cpu
self.backend = backend
self.verbose = verbose
self._opt = [
'-t{}'.format(self.wavelet_id),
'-n{}'.format(self.nb_scale),
]
self._has_run = False
self.coeffs_shape = None

def _get_filters(self, shape):
"""Function to get the Wavelet coefficients of Delta[0][0].
chaithyagr marked this conversation as resolved.
Show resolved Hide resolved
This function is called only once and later the
wavelet coefficients are obtained by convolving these coefficients
with input Data
"""
self.transform = get_mr_filters(
tuple(shape),
opt=self._opt,
coarse=True,
)
self._has_run = True

def _op(self, data):
""" Define the wavelet operator for single channel.
This is internal function that returns wavelet coefficients for a
single channel
Parameters
----------
data: ndarray or Image
input 2D data array.

Returns
-------
coeffs: ndarray
the wavelet coefficients.
"""
coefs_real = filter_convolve(data.real, self.transform)
coefs_imag = filter_convolve(data.imag, self.transform)
coeffs, coeffs_shape = flatten(
coefs_real + 1j * coefs_imag)
return coeffs, coeffs_shape

def op(self, data):
""" Define the wavelet operator.

This method returns the input data convolved with the wavelet filter.

Parameters
----------
data: ndarray or Image
input 2D data array.

Returns
-------
coeffs: ndarray
the wavelet coefficients.
"""
if not self._has_run:
if self.multichannel:
self._get_filters(list(data.shape)[1:])
else:
self._get_filters(data.shape)
if self.multichannel:
coeffs, self.coeffs_shape = zip(*Parallel(n_jobs=self.n_cpu,
backend=self.backend,
verbose=self.verbose)(
delayed(self._op)
(data[i])
for i in np.arange(data.shape[0])))
coeffs = np.asarray(coeffs)
else:
coeffs, self.coeffs_shape = self._op(data)
return coeffs

def _adj_op(self, coefs, coeffs_shape):
"""" Define the wavelet adjoint operator.

This method returns the reconstructed image for single channel.

Parameters
----------
coeffs: ndarray
the wavelet coefficients.
coeffs_shape: ndarray
The shape of coefficients to unflatten before adjoint operation
Returns
-------
data: ndarray
the reconstructed data.
"""
data_real = filter_convolve(
np.squeeze(unflatten(coefs.real, coeffs_shape)),
self.transform, filter_rot=True)
data_imag = filter_convolve(
np.squeeze(unflatten(coefs.imag, coeffs_shape)),
self.transform, filter_rot=True)
return data_real + 1j * data_imag

def adj_op(self, coefs):
""" Define the wavelet adjoint operator.

This method returns the reconsructed image.

Parameters
----------
coeffs: ndarray
the wavelet coefficients.

Returns
-------
data: ndarray
the reconstructed data.
"""
if not self._has_run:
raise RuntimeError(
"`op` must be run before `adj_op` to get the data shape",
)
if self.multichannel:
images = Parallel(n_jobs=self.n_cpu,
backend=self.backend,
verbose=self.verbose)(
delayed(self._adj_op)
(coefs[i], self.coeffs_shape[i])
for i in np.arange(coefs.shape[0]))
images = np.asarray(images)
else:
images = self._adj_op(coefs, self.coeffs_shape)
return images

def l2norm(self, shape):
""" Compute the L2 norm.
Parameters
----------
shape: uplet
the data shape.
Returns
-------
norm: float
the L2 norm.
"""
# Create fake data
shape = np.asarray(shape)
shape += shape % 2
fake_data = np.zeros(shape)
fake_data[tuple(zip(shape // 2))] = 1

# Call mr_transform
data = self.op(fake_data)

# Compute the L2 norm
return np.linalg.norm(data)
43 changes: 41 additions & 2 deletions mri/test/test_wavelet_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,23 @@
import numpy

# Package import
from mri.reconstruct.linear import WaveletN
from mri.reconstruct.linear import WaveletN, WaveletUD2


class TestAdjointOperatorWaveletTransform(unittest.TestCase):
""" Test the adjoint operator of the Wavelets both for 2D and 3D.
"""

def setUp(self):
""" Set the number of iterations.
""" Setup variables:
N = Image size
max_iter = Number of iterations to test
num_channels = Number of channels to be tested with for
multichannel tests
"""
self.N = 64
self.max_iter = 10
self.num_channels = 10

def test_Wavelet2D_ISAP(self):
"""Test the adjoint operator for the 2D Wavelet transform
Expand Down Expand Up @@ -80,6 +85,40 @@ def test_Wavelet3D_PyWt(self):
numpy.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
print(" Wavelet3 adjoint test passes")

def test_Wavelet_UD_2D(self):
"""Test the adjoint operation for Undecimated wavelet
"""
for i in range(self.max_iter):
print("Process Wavelet Undecimated test '{0}'...", i)
wavelet_op = WaveletUD2(nb_scale=4)
img = (numpy.random.randn(self.N, self.N) +
1j * numpy.random.randn(self.N, self.N))
f_p = wavelet_op.op(img)
f = (numpy.random.randn(*f_p.shape) +
1j * numpy.random.randn(*f_p.shape))
i_p = wavelet_op.adj_op(f)
x_d = numpy.vdot(img, i_p)
x_ad = numpy.vdot(f_p, f)
numpy.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
print("Undecimated Wavelet 2D adjoint test passes")

def test_Wavelet_UD_2D_Multichannel(self):
"""Test the adjoint operation for Undecmated wavelet Transform in
multichannel case"""
for i in range(self.max_iter):
print("Process Wavelet Undecimated test '{0}'...", i)
wavelet_op = WaveletUD2(nb_scale=4, multichannel=True, n_cpu=2)
chaithyagr marked this conversation as resolved.
Show resolved Hide resolved
img = (numpy.random.randn(self.num_channels, self.N, self.N) +
1j * numpy.random.randn(self.num_channels, self.N, self.N))
f_p = wavelet_op.op(img)
f = (numpy.random.randn(*f_p.shape) +
1j * numpy.random.randn(*f_p.shape))
i_p = wavelet_op.adj_op(f)
x_d = numpy.vdot(img, i_p)
x_ad = numpy.vdot(f_p, f)
numpy.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
print("Undecimated Wavelet 2D adjoint test passes for multichannel")


if __name__ == "__main__":
unittest.main()