Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] - extra bval parsing for number of shells #73

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 45 additions & 10 deletions dmriprep/utils/tests/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ def test_corruption(tmpdir, dipy_test_data, monkeypatch):
dgt.bvecs = bvecs[:-1]

# Missing b0
bval_no_b0 = bvals.copy()
bval_no_b0[0] = 51
with pytest.raises(ValueError):
dgt = v.DiffusionGradientTable(dwi_file=dipy_test_data['dwi_file'],
bvals=bval_no_b0, bvecs=bvecs)
bvec_no_b0 = bvecs.copy()
bvec_no_b0[0] = np.array([1.0, 0.0, 0.0])
with pytest.raises(ValueError):
dgt = v.DiffusionGradientTable(dwi_file=dipy_test_data['dwi_file'],
bvals=bvals, bvecs=bvec_no_b0)
# bval_no_b0 = bvals.copy()
# bval_no_b0[0] = 51
# with pytest.raises(ValueError):
# dgt = v.DiffusionGradientTable(dwi_file=dipy_test_data['dwi_file'],
# bvals=bval_no_b0, bvecs=bvecs)
# bvec_no_b0 = bvecs.copy()
# bvec_no_b0[0] = np.array([1.0, 0.0, 0.0])
# with pytest.raises(ValueError):
# dgt = v.DiffusionGradientTable(dwi_file=dipy_test_data['dwi_file'],
# bvals=bvals, bvecs=bvec_no_b0)

# Corrupt b0 b-val
bval_odd_b0 = bvals.copy()
Expand Down Expand Up @@ -89,3 +89,38 @@ def mock_func(*args, **kwargs):
# Miscellaneous tests
with pytest.raises(ValueError):
dgt.to_filename('path', filetype='mrtrix')


def test_bval_scheme(dipy_test_data):
'''basic smoke test'''
dgt = v.DiffusionGradientTable(**dipy_test_data)

bval_scheme = v.bval_centers(dgt)

print(bval_scheme)
assert len(np.unique(bval_scheme)) == 4
np.testing.assert_array_equal(bval_scheme.astype(int), dgt.bvals)


class FakeDiffTable():
''' just takes the bval input to diff table - for testing more options'''
def __init__(self, bvals):
self.bvals = bvals


def test_bval_when_only_b0_present():
''' all the weird schemes '''
all_zeros = np.array([0, 0])
res_a = v.bval_centers(FakeDiffTable(bvals=all_zeros))
np.testing.assert_array_equal(res_a, all_zeros)


def test_more_complez_bval():
bvals_b = [5, 300, 300, 300, 300, 300, 305, 1005, 995, 1000, 1000, 1005, 1000,
1000, 1005, 995, 1000, 1005, 5, 995, 1000, 1000, 995, 1005, 995, 1000,
995, 995, 2005, 2000, 2005, 2005, 1995, 2000, 2005, 2000, 1995, 2005, 5,
1995, 2005, 1995, 1995, 2005, 2005, 1995, 2000, 2000, 2000, 1995, 2000, 2000,
2005, 2005, 1995, 2005, 2005, 1990, 1995, 1995, 1995, 2005, 2000, 1990, 2010, 5]
res_b = v.bval_centers(FakeDiffTable(bvals=np.array(bvals_b)))
assert len(np.unique(res_b)) == 4
np.testing.assert_allclose(np.unique(res_b), np.array([5., 300.83333333, 999.5, 2000.]))
37 changes: 37 additions & 0 deletions dmriprep/utils/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import nibabel as nb
import numpy as np
from dipy.core.gradients import round_bvals
from sklearn.cluster import KMeans

B0_THRESHOLD = 50
BVEC_NORM_EPSILON = 0.1
SHELL_DIFF_THRES = 150


class DiffusionGradientTable:
Expand Down Expand Up @@ -265,6 +267,41 @@ def to_filename(self, filename, filetype="rasb"):
raise ValueError('Unknown filetype "%s"' % filetype)


def bval_centers(diffusion_table, shell_diff_thres=SHELL_DIFF_THRES):
"""
Parse the available bvals from DiffusionTable into shells

Parameters
----------
diffusion_table : DiffusionGradientTable object

Returns
-------
shell_centers : numpy.ndarray of length of bvals vector
Vector of bvals of shell centers
"""
bvals = diffusion_table.bvals

# use kmeans to find the shelling scheme
for k in range(1, len(np.unique(bvals)) + 1):
kmeans_res = KMeans(n_clusters=k).fit(bvals.reshape(-1, 1))
if kmeans_res.inertia_ / len(bvals) < shell_diff_thres:
break
else:
raise ValueError(
'Sorry, bval parsing failed - no shells '
'are more than {} apart are found'.format(shell_diff_thres)
)

# convert the kclust labels to an array
shells = kmeans_res.cluster_centers_
shell_centers_vec = np.zeros(bvals.shape)
for i in range(shells.size):
shell_centers_vec[kmeans_res.labels_ == i] = shells[i][0]

return shell_centers_vec


def normalize_gradients(bvecs, bvals, b0_threshold=B0_THRESHOLD,
bvec_norm_epsilon=BVEC_NORM_EPSILON, b_scale=True,
raise_error=False):
Expand Down