BundleSeg exploration viewer for filtering #1035

Open · wants to merge 17 commits into base: master
2 changes: 1 addition & 1 deletion .python-version
@@ -1,2 +1,2 @@
3.10
>=3.9,<3.11
>=3.9,<3.12
165 changes: 106 additions & 59 deletions scilpy/segment/recobundlesx.py → scilpy/segment/bundleseg.py
@@ -1,28 +1,39 @@
# -*- coding: utf-8 -*-

import gc
import logging
from time import time
import warnings

from dipy.align.streamlinear import (BundleMinDistanceMetric,
StreamlineLinearRegistration)
from dipy.segment.fss import FastStreamlineSearch
from dipy.segment.fss import FastStreamlineSearch, nearest_from_matrix_col
from dipy.segment.clustering import qbx_and_merge
from dipy.tracking.distances import bundles_distances_mdf
from dipy.tracking.streamline import (select_random_set_of_streamlines,
transform_streamlines)
from nibabel.streamlines.array_sequence import ArraySequence
import numpy as np
from scipy.sparse import vstack

from scilpy.io.streamlines import reconstruct_streamlines_from_memmap

logger = logging.getLogger('BundleSeg')

class RecobundlesX(object):

def get_duration(start_time):
"""
Helper function to get the duration of a process.
"""
return np.round(time() - start_time, 2)


class BundleSeg(object):
"""
This class is a 'remastered' version of the Dipy Recobundles class.
Made to be used in sync with the voting_scheme.
Adapted in many ways for HPC processing and increased control over
parameters and logging.
parameters and the logger.
However, in essence they do the same thing.
https://github.com/nipy/dipy/blob/master/dipy/segment/bundles.py

@@ -33,13 +44,15 @@ class RecobundlesX(object):
clustering, Neuroimage, 2017.
"""

def __init__(self, memmap_filenames, clusters_indices, wb_centroids,
rng=None):
def __init__(self, memmap_filenames, transformation,
clusters_indices, wb_centroids, rng=None):
"""
Parameters
----------
memmap_filenames : tuple
tuple of filenames for the data, offsets and lengths.
transformation : numpy.ndarray
Transformation matrix to apply to the model streamlines.
clusters_indices: ArraySequence
ArraySequence containing the indices of the streamlines per
cluster.
@@ -50,16 +63,19 @@ def __init__(self, memmap_filenames, clusters_indices, wb_centroids,
If None then RandomState is initialized internally.
"""
self.memmap_filenames = memmap_filenames
self.transformation = transformation
self.wb_clusters_indices = clusters_indices
self.centroids = wb_centroids
self.wb_centroids = wb_centroids
self.rng = rng

# For memory management
self.wb_centroids._data = self.wb_centroids._data.astype(np.float16)

# Declared here; assigned outside of __init__
self.neighb_centroids = None
self.neighb_indices = None
self.models_streamlines = None
self.model_streamlines = None
self.model_centroids = None
self.final_pruned_indices = None
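
The float16 downcasts above trade precision for memory. As a rough, standalone back-of-envelope sketch (illustrative numbers, not taken from this PR) of what that buys at whole-brain scale:

# Rough memory estimate for whole-brain centroid storage
# (illustrative numbers, not from the PR).
n_centroids = 100_000
pts, coords = 20, 3
print('float32: %.0f MB' % (n_centroids * pts * coords * 4 / 1e6))  # 24 MB
print('float16: %.0f MB' % (n_centroids * pts * coords * 2 / 1e6))  # 12 MB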

def recognize(self, model_bundle,
model_clust_thr=8, pruning_thr=8,
@@ -77,38 +93,39 @@ def recognize(self, model_bundle,
Define the transformation for the local SLR.
[translation, rigid, similarity, scaling]
identifier : str
Identify the current bundle being recognized for the logging.
Identify the current bundle being recognized for the logger.

Returns
-------
pruned_indices, pruned_scores : np.ndarray, np.ndarray
Indices of the streamlines recognized by BundleSeg, along with
the distance score of each to the model.
"""

self.model_streamlines = model_bundle

self._cluster_model_bundle(model_clust_thr,
identifier=identifier)

if not self._reduce_search_space(neighbors_reduction_thr=16):
if self._reduce_search_space(neighbors_reduction_thr=16) == 0:
if identifier:
logging.error('{0} did not find any neighbors in '
'the tractogram'.format(identifier))
return np.array([], dtype=np.uint32)
logger.error(f'{identifier} did not find any neighbors in '
'the tractogram')
return [], []

self._register_model_to_neighb(slr_transform_type=slr_transform_type)

# self._reduce_search_space(neighbors_reduction_thr=12)
if not self._reduce_search_space(neighbors_reduction_thr=12):
if self._reduce_search_space(neighbors_reduction_thr=14) == 0:
if identifier:
logging.error('{0} did not find any neighbors in '
'the tractogram'.format(identifier))
return np.array([], dtype=np.uint32)
logger.error(f'{identifier} did not find any neighbors in '
'the tractogram')
return [], []
del self.neighb_centroids, self.model_centroids
gc.collect()

self.prune_far_from_model(pruning_thr=pruning_thr)
pruned_indices, pruned_scores = self.prune_far_from_model(
pruning_thr=pruning_thr)

self.cleanup()
return self.get_final_pruned_indices()
return pruned_indices, pruned_scores
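
To make the control flow above concrete, here is a minimal driving sketch. All inputs (`memmaps`, `affine`, `cluster_indices`, `centroids`, `model_bundle`) are assumed to come from an upstream whole-brain QBx clustering and loading step; the names are illustrative, not the actual voting_scheme wiring:

# Hypothetical driver for BundleSeg.recognize(); every input below is
# assumed to be prepared upstream (names are illustrative).
import numpy as np

bsg = BundleSeg(memmaps, affine, cluster_indices, centroids,
                rng=np.random.RandomState(42))
indices, scores = bsg.recognize(model_bundle,
                                model_clust_thr=8, pruning_thr=8,
                                slr_transform_type='similarity',
                                identifier='AF_left')
if len(indices) == 0:
    print('No streamlines recognized for this model.')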

def _cluster_model_bundle(self, model_clust_thr, identifier):
"""
@@ -117,7 +134,7 @@ def _cluster_model_bundle(self, model_clust_thr, identifier):
Parameters
----------
model_clust_thr, float, distance in mm for clustering.
identifier, str, name of the bundle for logging.
identifier, str, name of the bundle for the logger.
"""
thresholds = [30, 20, 15, model_clust_thr]
model_cluster_map = qbx_and_merge(self.model_streamlines, thresholds,
@@ -126,12 +143,13 @@ def _cluster_model_bundle(self, model_clust_thr, identifier):
verbose=False)

self.model_centroids = ArraySequence(model_cluster_map.centroids)
self.model_centroids._data = self.model_centroids._data.astype(
np.float16)

len_centroids = len(self.model_centroids)
if len_centroids > 1000:
logging.warning('Model {0} simplified at threshold '
'{1}mm with {2} centroids'.format(identifier,
str(model_clust_thr),
str(len_centroids)))
logger.warning(f'Model {identifier} simplified at threshold '
f'{model_clust_thr}mm with {len_centroids} centroids')

def _reduce_search_space(self, neighbors_reduction_thr=18):
"""
@@ -140,8 +158,9 @@ def _reduce_search_space(self, neighbors_reduction_thr=18):
:param neighbors_reduction_thr, float, distance in mm for thresholding
to discard distant streamlines.
"""

centroid_matrix = bundles_distances_mdf(self.model_centroids,
self.centroids).astype(np.float16)
self.wb_centroids).astype(np.float16)
centroid_matrix[centroid_matrix >
neighbors_reduction_thr] = np.inf

@@ -154,12 +173,14 @@ def _reduce_search_space(self, neighbors_reduction_thr=18):
self.neighb_indices.extend(self.wb_clusters_indices[i])
self.neighb_indices = np.array(self.neighb_indices, dtype=np.uint32)

self.neighb_centroids = [self.centroids[i]
for i in close_clusters_indices]
self.neighb_centroids = ArraySequence([self.wb_centroids[i]
for i in close_clusters_indices])
self.neighb_centroids._data = self.neighb_centroids._data.astype(
np.float16)

return self.neighb_indices.size
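
As a standalone toy illustration of this reduction (not part of the diff): bundles_distances_mdf gives a (model x whole-brain) centroid distance matrix, and any whole-brain cluster whose column keeps at least one finite entry after thresholding survives as a neighbor:

# Toy sketch of the MDF-based search-space reduction, with random data.
import numpy as np
from dipy.tracking.distances import bundles_distances_mdf

rng = np.random.RandomState(0)
model_centroids = [(rng.rand(20, 3) * 10).astype(np.float32)
                   for _ in range(3)]
wb_centroids = [(rng.rand(20, 3) * 100).astype(np.float32)
                for _ in range(50)]

dist = bundles_distances_mdf(model_centroids, wb_centroids)
dist[dist > 18] = np.inf  # same thresholding idea as above
close = np.where(np.min(dist, axis=0) != np.inf)[0]
print(f'{close.size} of {len(wb_centroids)} clusters kept as neighbors')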

def _register_model_to_neighb(self, select_model=1000, select_target=1000,
def _register_model_to_neighb(self, select_model=250, select_target=250,
slr_transform_type='similarity'):
"""
Parameters
@@ -179,6 +200,7 @@ def _register_model_to_neighb(self, select_model=1000, select_target=1000,
"""
possible_slr_transform_type = {'translation': 0, 'rigid': 1,
'similarity': 2, 'scaling': 3}

static = select_random_set_of_streamlines(self.model_centroids,
select_model, self.rng)
moving = select_random_set_of_streamlines(self.neighb_centroids,
@@ -233,8 +255,12 @@ def _register_model_to_neighb(self, select_model=1000, select_target=1000,
bounds=bounds_dof[:9],
num_threads=1)
slm = slr.optimize(static, moving)
self.model_centroids = transform_streamlines(self.model_centroids,
np.linalg.inv(slm.matrix))

# Apply the transformation to the model streamlines
self.model_streamlines = transform_streamlines(self.model_streamlines,
np.linalg.inv(slm.matrix))
self.model_streamlines._data = self.model_streamlines._data.astype(
np.float16)
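
A note on the inverse above: the SLR registers the moving neighbor centroids onto the static model centroids, so slm.matrix maps tractogram space to model space. Applying its inverse moves the model into the tractogram's native space instead, which avoids transforming the far larger tractogram. A tiny self-contained check of that round trip (stand-in matrix, not an actual SLR result):

# Toy check that applying a transform then its inverse is a no-op,
# mirroring why the model is moved with np.linalg.inv(slm.matrix).
import numpy as np
from dipy.tracking.streamline import transform_streamlines

sl = [np.random.RandomState(1).rand(12, 3)]
mat = np.eye(4)
mat[:3, 3] = [5., 0., 0.]  # stand-in for slm.matrix
back = transform_streamlines(transform_streamlines(sl, mat),
                             np.linalg.inv(mat))
assert np.allclose(sl[0], back[0], atol=1e-6)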

def prune_far_from_model(self, pruning_thr=10):
"""
@@ -248,37 +274,58 @@ def prune_far_from_model(self, pruning_thr=10):
"""
# Neighbors can be refined since the search space is smaller
t0 = time()

neighb_streamlines = reconstruct_streamlines_from_memmap(
self.memmap_filenames, self.neighb_indices, strs_dtype=np.float16)

# Typically the neighborhood is bigger than the model, so we flip
# the FSS to be more memory efficient
with warnings.catch_warnings(record=True) as _:
fss = FastStreamlineSearch(neighb_streamlines,
fss = FastStreamlineSearch(self.model_streamlines,
pruning_thr, resampling=12)
dist_mat = fss.radius_search(self.model_streamlines,
pruning_thr)

logging.debug("Fast search took of dimensions {0}: {1} sec.".format(
dist_mat.shape, np.round(time() - t0, 2)))

sparse_dist_mat = np.abs(dist_mat.tocsr())
sparse_dist_vec = np.squeeze(np.max(sparse_dist_mat, axis=0).toarray())
pruned_indices = np.where(sparse_dist_vec > 1e-6)[0]

# Since the neighbors were clustered, a mapping of indices is necessary
self.final_pruned_indices = self.neighb_indices[pruned_indices].astype(
np.uint32)

return self.final_pruned_indices

def get_final_pruned_indices(self):
"""
Public getter for the final indices recognize by the algorithm.
"""
return self.final_pruned_indices
# Search in fixed-size chunks to bound peak memory usage
CHUNK_SIZE = 1000
dist_mat_list = []
for chunk_id in range(0, len(neighb_streamlines), CHUNK_SIZE):
tmp_dist_mat = fss.radius_search(
neighb_streamlines[chunk_id:chunk_id+CHUNK_SIZE],
pruning_thr)
tmp_dist_mat.data = tmp_dist_mat.data.astype(np.float16)
dist_mat_list.append(tmp_dist_mat.copy())
del tmp_dist_mat
gc.collect()

dist_mat = vstack(dist_mat_list)
del dist_mat_list
gc.collect()
dist_mat.data = dist_mat.data.astype(np.float16)
dist_mat = dist_mat.T

logger.debug(f'Fast search of dimensions {dist_mat.shape} took '
f'{get_duration(t0)} sec.')
if dist_mat.size == 0 or dist_mat.shape[1] <= 1:
return [], []

# Identify the closest neighbors (zeros mean no match and are dropped)
np.absolute(dist_mat.data, out=dist_mat.data)
non_zero_ids, _, scores = nearest_from_matrix_col(dist_mat)

del dist_mat, neighb_streamlines
del fss.ref_slines, fss.bin_dict, fss
gc.collect()

# Return empty lists if no streamlines were recognized
if len(non_zero_ids) != 0:
# Since the neighbors were clustered, a mapping of indices is necessary
final_pruned_indices = self.neighb_indices[non_zero_ids].astype(
np.uint32)
final_pruned_scores = scores.astype(np.float16)
return final_pruned_indices, final_pruned_scores
else:
return [], []
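
To make the matching step concrete: nearest_from_matrix_col scans each column of the sparse (model x neighbor) distance matrix and, for every neighbor that matched at least one model streamline, returns the neighbor's column id, the nearest model row, and the distance. A handmade example (expected outputs per the DIPY docstring; treat as a sketch):

# Toy sketch of the column-wise nearest match on a handmade
# 2-model x 3-neighbor sparse distance matrix.
import numpy as np
from scipy.sparse import coo_matrix
from dipy.segment.fss import nearest_from_matrix_col

#             n0    n1    n2
# model 0:   4.0    .    6.0
# model 1:   2.0    .     .     (neighbor 1 never matched)
dist = coo_matrix(([4., 6., 2.], ([0, 0, 1], [0, 2, 0])), shape=(2, 3))

matched, nearest_row, scores = nearest_from_matrix_col(dist)
print(matched)      # expected: [0 2]
print(nearest_row)  # expected: [1 0]
print(scores)       # expected: [2. 6.]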

def cleanup(self):
del self.neighb_centroids
del self.neighb_indices
del self.model_streamlines
del self.model_centroids
# Drop remaining references so the memory can actually be reclaimed
self.neighb_indices = None
self.wb_clusters_indices = None
del self.model_streamlines._data, self.model_streamlines