Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] - extra bval parsing for number of shells #73

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions dmriprep/utils/tests/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,31 @@ def mock_func(*args, **kwargs):
# Miscellaneous tests
with pytest.raises(ValueError):
dgt.to_filename('path', filetype='mrtrix')

def test_bval_scheme(dipy_test_data):
'''basic smoke test'''
bvals = dipy_test_data['bvals']

bval_scheme = v.BVALScheme(bvals = bvals)

print(bval_scheme)
assert bval_scheme.n_shells == 3 ## just see if if can get here..

def test_bval_when_only_b0_present():
''' all the weird schemes '''
res_a = v.BVALScheme(bvals = np.array([0,0]))
print(res_a)
assert res_a.n_b0 == 2
assert res_a.n_shells == 0

def test_more_complez_bval():
bvals_b = [5, 300, 300, 300, 300, 300, 305, 1005, 995, 1000, 1000, 1005, 1000,
1000, 1005, 995, 1000, 1005, 5, 995, 1000, 1000, 995, 1005, 995, 1000,
995, 995, 2005, 2000, 2005, 2005, 1995, 2000, 2005, 2000, 1995, 2005, 5,
1995, 2005, 1995, 1995, 2005, 2005, 1995, 2000, 2000, 2000, 1995, 2000, 2000,
2005, 2005, 1995, 2005, 2005, 1990, 1995, 1995, 1995, 2005, 2000, 1990, 2010, 5]
res_b = v.BVALScheme(bvals = np.array(bvals_b))
print(res_b)
assert res_b.n_b0 == 4
assert res_b.n_shells == 3

101 changes: 101 additions & 0 deletions dmriprep/utils/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import nibabel as nb
import numpy as np
from dipy.core.gradients import round_bvals
from sklearn.cluster import KMeans

B0_THRESHOLD = 50
BVEC_NORM_EPSILON = 0.1
SHELL_DIFF_THRES = 150


class DiffusionGradientTable:
Expand Down Expand Up @@ -177,6 +179,105 @@ def to_filename(self, filename, filetype='rasb'):
else:
raise ValueError('Unknown filetype "%s"' % filetype)

class BVALScheme:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would probably go with a single function that accepts a DiffusionGradientTable object for this.

The function would return only a list of nonzero shell centers and the number of shells would be just the len() of that (which can be easily calculated by the user).

WDYT?

"""Data structure for bval scheme."""

def __init__(self, bvals = None, shell_diff_thres = SHELL_DIFF_THRES):
"""
Parse the available bvals into shells

Parameters
----------
bvals : numpy.ndarray of b-values
"""

self._shell_diff_thres = shell_diff_thres
self._bvals = np.array(bvals)

self._kclust = self._k_cluster_result()

def _k_cluster_result(self):
''' determine the shell system by running k clustering return a dict with masks separated by shell'''
for k in range(1,len(np.unique(self._bvals)) + 1):
kmeans_res = KMeans(n_clusters = k).fit(self._bvals.reshape(-1,1))
if kmeans_res.inertia_/len(self._bvals) < self._shell_diff_thres:
return kmeans_res
print('Sorry, bval parsing failed - no shells are more than {} apart found'.format(shell_diff_thres))
return None

@property
def bvals(self):
"""Get the N b-values."""
return self._bvals

@property
def shells(self):
'''return sorted shells rounded to nearest 100'''
shells = np.round(np.squeeze(self._kclust.cluster_centers_),-2)
if shells.size == 1:
return np.array(shells).reshape(1,-1) #convert back to iterable type
else:
return np.sort(shells)

@property
def n_shells(self):
''' returns number of non-zero shells'''
return sum(self.shells != 0)

def get_shell_centers(self, shell = 'all'):
''' returns non rounded shell centers'''
all_centers = np.squeeze(self._kclust.cluster_centers_)
if all_centers.size > 1:
all_centers = np.sort(all_centers)
else:
all_centers = np.array(all_centers).reshape(1,-1)
if shell == 'all':
return all_centers
elif shell in self.shells:
return all_centers[self.shells == shell]
else:
print("could not find shell {} in bvals".format(shell))
return None

def get_shell_mask(self, shell):
''' returns the mask for a given shell'''
shell_center = self.get_shell_centers(shell = shell)
clustid = np.where(self._kclust.cluster_centers_ == shell_center)[0]
mask = self._kclust.labels_ == clustid
return mask

def get_n_directions_in_shell(self, shell):
''' returns the number of directions in a shell'''
return sum(self.get_shell_mask(shell))

@property
def b0_mask(self):
return self.get_shell_mask(shell = 0)

@property
def n_b0(self):
'''returns number of b0s'''
return np.sum(self.b0_mask)

@property
def total_directions(self):
'''prints number of non-b0 directions (assuming they are unique)'''
return np.sum(np.invert(self.b0_mask))

def __str__(self):
''' prints pretty string summary of bvals '''
shell_list = []
for shell in self.shells:
if shell > 0:
shell_list.append("{}:{}".format(int(shell),
self.get_n_directions_in_shell(shell)))
shell_str = "{} B0s + {} directions in {} shell(s) | B-value:n_directions {}".format(
self.n_b0,
self.total_directions,
self.n_shells,
", ".join(shell_list))
return shell_str


def normalize_gradients(bvecs, bvals, b0_threshold=B0_THRESHOLD,
bvec_norm_epsilon=BVEC_NORM_EPSILON, b_scale=True):
Expand Down