Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
enh: externalize sorting functions
Browse files Browse the repository at this point in the history
  • Loading branch information
esavary committed Apr 4, 2024
1 parent a1bcacd commit 3343f34
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 37 deletions.
44 changes: 7 additions & 37 deletions src/eddymotion/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from eddymotion.data.splitting import lovo_split
from eddymotion.model import ModelFactory
from utils import sort_dwdata_indices, SortingStrategy


class EddyMotionEstimator:
Expand Down Expand Up @@ -88,22 +89,20 @@ def fit(

align_kwargs = align_kwargs or {}

index_order = _sort_dwdata_indices(seed, len(dwdata))
index_order = sort_dwdata_indices(
dwdata, SortingStrategy.RANDOM, seed=None
)

if "num_threads" not in align_kwargs and omp_nthreads is not None:
align_kwargs["num_threads"] = omp_nthreads

n_iter = len(models)
for i_iter, model in enumerate(models):
reg_target_type = (
"dwi" if model.lower() not in ("b0", "s0", "avg", "average", "mean") else "b0"
bmask_img = _prepare_brainmask_data(
dwdata.brainmask,
dwdata.affine
)

# When downsampling these need to be set per-level
bmask_img = _prepare_brainmask_data(dwdata.brainmask, dwdata.affine)

_prepare_kwargs(dwdata, kwargs)

single_model = model.lower() in (
"b0",
"s0",
Expand Down Expand Up @@ -241,35 +240,6 @@ def _to_nifti(data, affine, filename, clip=True):
nii.to_filename(filename)


def _sort_dwdata_indices(seed, dwi_vol_count):
"""Sort the DWI data volume indices.
Parameters
----------
seed : :obj:`int` or :obj:`bool`
Seed the random number generator. If an integer, the value is used to initialize the
generator; if ``True``, the arbitrary value of ``20210324`` is used to initialize it.
dwi_vol_count : :obj:`int`
Number of DWI volumes.
Returns
-------
index_order : :obj:`numpy.ndarray`
Index order.
"""

_seed = None
if seed or seed == 0:
_seed = 20210324 if seed is True else seed

rng = np.random.default_rng(_seed)

index_order = np.arange(dwi_vol_count)
rng.shuffle(index_order)

return index_order


def _prepare_brainmask_data(brainmask, affine):
"""Prepare the brainmask data: save the data to disk.
Expand Down
173 changes: 173 additions & 0 deletions src/eddymotion/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2022 The NiPreps Developers <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
"""Utils to sort the DWI data volume indices """

from enum import Enum
import numpy as np


class SortingStrategy(Enum):
"""
Enum class representing different sorting strategies.
Available sorting strategies:
- LINEAR: Sorts the items in a linear order.
- RANDOM: Sorts the items in a random order.
- BVALUE: Sorts the items based on their b-value.
- CENTRALSYM: Sorts the items based on their central symmetry.
"""
LINEAR = "linear"
RANDOM = "random"
BVALUE = "bvalue"
CENTRALSYM = "centralsym"


def sort_dwdata_indices(dwdata, strategy, seed=None):
"""Sort the DWI data volume indices following the given strategy.
Parameters
----------
dwdata : :obj:`~eddymotion.dmri.DWI`
DWI dataset, represented by this tool's internal type.
strategy : :obj:`~eddymotion.utils.SortingStrategy`
The sorting strategy to be used. Available options are:
- SortingStrategy.LINEAR: Sort the indices linearly.
- SortingStrategy.RANDOM: Sort the indices randomly.
- SortingStrategy.BVALUE: Sort the indices based on the last column of gradients in ascending order.
- SortingStrategy.CENTRALSYM: Sort the indices in a central symmetric manner.
seed : :obj:`int` or :obj:`bool`, optional
Seed the random number generator. If an integer, the value is used to
initialize the generator; if ``True``, the arbitrary value
of ``20210324`` is used to initialize it.
Returns
-------
index_order : :obj:`numpy.ndarray`
The sorted index order.
"""
if strategy == SortingStrategy.LINEAR:
return linear_action(dwdata)
elif strategy == SortingStrategy.RANDOM:
return random_action(dwdata, seed)
elif strategy == SortingStrategy.BVALUE:
return bvalue_action(dwdata)
elif strategy == SortingStrategy.CENTRALSYM:
return centralsym_action(dwdata)
else:
raise ValueError("Invalid sorting strategy")


def linear_action(dwdata):
"""
Sort the DWI data volume indices linearly
Parameters:
dwdata : :obj:`~eddymotion.dmri.DWI`
DWI dataset, represented by this tool's internal type.
Returns:
index_order : :obj:`numpy.ndarray`
The sorted index order.
"""
index_order = np.arange(len(dwdata))

return index_order


def random_action(dwdata, seed=None):
"""Sort the DWI data volume indices.
Parameters
----------
dwdata : :obj:`~eddymotion.dmri.DWI`
DWI dataset, represented by this tool's internal type.
seed : :obj:`int` or :obj:`bool`, optional
Seed the random number generator. If an integer, the value is used to
initialize the generator; if ``True``, the arbitrary value
of ``20210324`` is used to initialize it.
Returns
-------
index_order : :obj:`numpy.ndarray`
The sorted index order.
"""

_seed = None
if seed or seed == 0:
_seed = 20210324 if seed is True else seed

rng = np.random.default_rng(_seed)

index_order = np.arange(len(dwdata))
rng.shuffle(index_order)

return index_order


def bvalue_action(dwdata):
"""
Sort the DWI data volume indices in ascending order based on the last
column of gradients.
Parameters:
dwdata : :obj:`~eddymotion.dmri.DWI`
DWI dataset, represented by this tool's internal type.
Returns:
numpy.ndarray: The sorted index order.
"""
last_column = dwdata.gradients[:, -1]
index_order = np.argsort(last_column)
return index_order


def centralsym_action(dwdata):
"""
Sort the DWI data volume indices in a central symmetric manner.
Parameters:
dwdata : :obj:`~eddymotion.dmri.DWI`
DWI dataset, represented by this tool's internal type.
Returns:
numpy.ndarray: The sorted index order.
"""
old_index = np.arange(len(dwdata))

index_order = old_index.copy()
if len(old_index) % 2 == 0:
middle_point = int(len(old_index) / 2-1)
index_order[0] = old_index[middle_point]

for i in np.arange(1, middle_point+1):
index_order[2*i-1] = old_index[middle_point + i]
index_order[2*i] = old_index[middle_point - i]
else:
middle_point = int(len(old_index) / 2)
index_order[0] = old_index[middle_point]
for i in np.arange(1, middle_point+1):
index_order[2*i-1] = old_index[middle_point + i]
index_order[2*i] = old_index[middle_point - i]

return index_order

0 comments on commit 3343f34

Please sign in to comment.