diff --git a/dmriprep/utils/tests/test_vectors.py b/dmriprep/utils/tests/test_vectors.py index 46a57917..77f3fd16 100644 --- a/dmriprep/utils/tests/test_vectors.py +++ b/dmriprep/utils/tests/test_vectors.py @@ -42,16 +42,16 @@ def test_corruption(tmpdir, dipy_test_data, monkeypatch): dgt.bvecs = bvecs[:-1] # Missing b0 - bval_no_b0 = bvals.copy() - bval_no_b0[0] = 51 - with pytest.raises(ValueError): - dgt = v.DiffusionGradientTable(dwi_file=dipy_test_data['dwi_file'], - bvals=bval_no_b0, bvecs=bvecs) - bvec_no_b0 = bvecs.copy() - bvec_no_b0[0] = np.array([1.0, 0.0, 0.0]) - with pytest.raises(ValueError): - dgt = v.DiffusionGradientTable(dwi_file=dipy_test_data['dwi_file'], - bvals=bvals, bvecs=bvec_no_b0) +# bval_no_b0 = bvals.copy() +# bval_no_b0[0] = 51 +# with pytest.raises(ValueError): +# dgt = v.DiffusionGradientTable(dwi_file=dipy_test_data['dwi_file'], +# bvals=bval_no_b0, bvecs=bvecs) +# bvec_no_b0 = bvecs.copy() +# bvec_no_b0[0] = np.array([1.0, 0.0, 0.0]) +# with pytest.raises(ValueError): +# dgt = v.DiffusionGradientTable(dwi_file=dipy_test_data['dwi_file'], +# bvals=bvals, bvecs=bvec_no_b0) # Corrupt b0 b-val bval_odd_b0 = bvals.copy() @@ -89,3 +89,38 @@ def mock_func(*args, **kwargs): # Miscellaneous tests with pytest.raises(ValueError): dgt.to_filename('path', filetype='mrtrix') + + +def test_bval_scheme(dipy_test_data): + '''basic smoke test''' + dgt = v.DiffusionGradientTable(**dipy_test_data) + + bval_scheme = v.bval_centers(dgt) + + print(bval_scheme) + assert len(np.unique(bval_scheme)) == 4 + np.testing.assert_array_equal(bval_scheme.astype(int), dgt.bvals) + + +class FakeDiffTable(): + ''' just takes the bval input to diff table - for testing more options''' + def __init__(self, bvals): + self.bvals = bvals + + +def test_bval_when_only_b0_present(): + ''' all the weird schemes ''' + all_zeros = np.array([0, 0]) + res_a = v.bval_centers(FakeDiffTable(bvals=all_zeros)) + np.testing.assert_array_equal(res_a, all_zeros) + + +def test_more_complez_bval(): + bvals_b = [5, 300, 300, 300, 300, 300, 305, 1005, 995, 1000, 1000, 1005, 1000, + 1000, 1005, 995, 1000, 1005, 5, 995, 1000, 1000, 995, 1005, 995, 1000, + 995, 995, 2005, 2000, 2005, 2005, 1995, 2000, 2005, 2000, 1995, 2005, 5, + 1995, 2005, 1995, 1995, 2005, 2005, 1995, 2000, 2000, 2000, 1995, 2000, 2000, + 2005, 2005, 1995, 2005, 2005, 1990, 1995, 1995, 1995, 2005, 2000, 1990, 2010, 5] + res_b = v.bval_centers(FakeDiffTable(bvals=np.array(bvals_b))) + assert len(np.unique(res_b)) == 4 + np.testing.assert_allclose(np.unique(res_b), np.array([5., 300.83333333, 999.5, 2000.])) diff --git a/dmriprep/utils/vectors.py b/dmriprep/utils/vectors.py index 00918ec8..b5f0c156 100644 --- a/dmriprep/utils/vectors.py +++ b/dmriprep/utils/vectors.py @@ -5,9 +5,11 @@ import nibabel as nb import numpy as np from dipy.core.gradients import round_bvals +from sklearn.cluster import KMeans B0_THRESHOLD = 50 BVEC_NORM_EPSILON = 0.1 +SHELL_DIFF_THRES = 150 class DiffusionGradientTable: @@ -265,6 +267,41 @@ def to_filename(self, filename, filetype="rasb"): raise ValueError('Unknown filetype "%s"' % filetype) +def bval_centers(diffusion_table, shell_diff_thres=SHELL_DIFF_THRES): + """ + Parse the available bvals from DiffusionTable into shells + + Parameters + ---------- + diffusion_table : DiffusionGradientTable object + + Returns + ------- + shell_centers : numpy.ndarray of length of bvals vector + Vector of bvals of shell centers + """ + bvals = diffusion_table.bvals + + # use kmeans to find the shelling scheme + for k in range(1, len(np.unique(bvals)) + 1): + kmeans_res = KMeans(n_clusters=k).fit(bvals.reshape(-1, 1)) + if kmeans_res.inertia_ / len(bvals) < shell_diff_thres: + break + else: + raise ValueError( + 'Sorry, bval parsing failed - no shells ' + 'are more than {} apart are found'.format(shell_diff_thres) + ) + + # convert the kclust labels to an array + shells = kmeans_res.cluster_centers_ + shell_centers_vec = np.zeros(bvals.shape) + for i in range(shells.size): + shell_centers_vec[kmeans_res.labels_ == i] = shells[i][0] + + return shell_centers_vec + + def normalize_gradients(bvecs, bvals, b0_threshold=B0_THRESHOLD, bvec_norm_epsilon=BVEC_NORM_EPSILON, b_scale=True, raise_error=False):