Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

ENH: Indexes sorting options #150

Merged
merged 47 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
3343f34
enh: externalize sorting functions
esavary Apr 4, 2024
1e305d9
fix:import statement
esavary Apr 4, 2024
6a8a585
sty: fix style error
esavary Apr 4, 2024
4e199a9
sty: fix style error
esavary Apr 4, 2024
b13c975
sty: fix style errors
esavary Apr 4, 2024
148bb71
Apply suggestions from code review
esavary Apr 5, 2024
5155f9f
fix: revert stylistic changes
esavary Apr 5, 2024
9420ce5
fix: revert stylistic changes
esavary Apr 5, 2024
3d7afb1
fix: revert stylistic changes
esavary Apr 5, 2024
69611fa
Update src/eddymotion/utils.py
esavary Apr 5, 2024
31fa9d6
Update src/eddymotion/utils.py
esavary Apr 5, 2024
6320f66
Update src/eddymotion/utils.py
esavary Apr 5, 2024
4998afa
fix: revert changes in estimator.py
esavary Apr 5, 2024
a424582
Apply suggestions from code review
esavary Apr 5, 2024
88bea1d
fix: typos
esavary Apr 5, 2024
b58a708
enh: update args and test for bvalue_action
esavary Apr 5, 2024
4366683
fix: docstring
esavary Apr 5, 2024
ee95877
fix: remove unused import
esavary Apr 5, 2024
e84bddd
Apply suggestions from code review
esavary Apr 5, 2024
3c0b36d
fix: add link for new module documentation
esavary Apr 5, 2024
22bbb68
fix: add exeption
esavary Apr 5, 2024
43fe703
fix: bvalue_action implementation + typos
esavary Apr 5, 2024
11f52a4
sty: remove white space
esavary Apr 5, 2024
cdfe20a
fix: typos
esavary Apr 5, 2024
f858a6a
sty: remove white space
esavary Apr 5, 2024
90cf31c
Apply suggestions from code review
esavary Apr 8, 2024
fa43376
sty: change iterator names
esavary Apr 8, 2024
cdeb34f
sty: fix docstring
esavary Apr 8, 2024
965b3c5
Apply suggestions from code review
esavary Apr 8, 2024
a43fdde
fix: random seed and add test
esavary Apr 8, 2024
ae19816
fix: typos
esavary Apr 8, 2024
d78c500
enh: generalize docstring
esavary Apr 8, 2024
6aaa7d4
Update src/eddymotion/utils.py
esavary Apr 8, 2024
4e7a418
Apply suggestions from code review
esavary Apr 8, 2024
4163b0b
Apply suggestions from code review
esavary Apr 8, 2024
ed20ca8
fix: random generator test results
esavary Apr 8, 2024
d966c97
Apply suggestions from code review
esavary Apr 8, 2024
0057079
sty: add typing
esavary Apr 8, 2024
6f0d5fe
fix: remove unused import
esavary Apr 8, 2024
2cc923a
Apply suggestions from code review
esavary Apr 8, 2024
9bda539
fix: import position
esavary Apr 8, 2024
d1c9846
Apply suggestions from code review
esavary Apr 8, 2024
807a468
fix: iterator typing in docstring
esavary Apr 8, 2024
fee865d
sty: fix ruff errors
esavary Apr 8, 2024
5f0837d
Merge branch 'nipreps:main' into index-sorting
esavary Apr 8, 2024
caa72a5
sty: fix sty errors
esavary Apr 8, 2024
911a730
sty: fix sty errors
esavary Apr 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/eddymotion/data/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
# https://www.nipreps.org/community/licensing/
#
"""Data splitting helpers."""

from pathlib import Path
import numpy as np

import h5py
import numpy as np


def lovo_split(dataset, index, with_b0=False):
Expand Down
39 changes: 2 additions & 37 deletions src/eddymotion/estimator.py
oesteban marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from eddymotion.data.splitting import lovo_split
from eddymotion.model import ModelFactory
from eddymotion.utils import SortingStrategy, sort_dwdata_indices


class EddyMotionEstimator:
Expand Down Expand Up @@ -88,22 +89,15 @@ def fit(

align_kwargs = align_kwargs or {}

index_order = _sort_dwdata_indices(seed, len(dwdata))
index_order = sort_dwdata_indices(dwdata, SortingStrategy.RANDOM, seed=None)

if "num_threads" not in align_kwargs and omp_nthreads is not None:
align_kwargs["num_threads"] = omp_nthreads

n_iter = len(models)
for i_iter, model in enumerate(models):
reg_target_type = (
"dwi" if model.lower() not in ("b0", "s0", "avg", "average", "mean") else "b0"
)

# When downsampling these need to be set per-level
bmask_img = _prepare_brainmask_data(dwdata.brainmask, dwdata.affine)

_prepare_kwargs(dwdata, kwargs)

single_model = model.lower() in (
"b0",
"s0",
Expand Down Expand Up @@ -241,35 +235,6 @@ def _to_nifti(data, affine, filename, clip=True):
nii.to_filename(filename)


def _sort_dwdata_indices(seed, dwi_vol_count):
"""Sort the DWI data volume indices.

Parameters
----------
seed : :obj:`int` or :obj:`bool`
Seed the random number generator. If an integer, the value is used to initialize the
generator; if ``True``, the arbitrary value of ``20210324`` is used to initialize it.
dwi_vol_count : :obj:`int`
Number of DWI volumes.

Returns
-------
index_order : :obj:`numpy.ndarray`
Index order.
"""

_seed = None
if seed or seed == 0:
_seed = 20210324 if seed is True else seed

rng = np.random.default_rng(_seed)

index_order = np.arange(dwi_vol_count)
rng.shuffle(index_order)

return index_order


def _prepare_brainmask_data(brainmask, affine):
"""Prepare the brainmask data: save the data to disk.

Expand Down
176 changes: 176 additions & 0 deletions src/eddymotion/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2022 The NiPreps Developers <[email protected]>
esavary marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
"""Utils to sort the DWI data volume indices"""
esavary marked this conversation as resolved.
Show resolved Hide resolved

from enum import Enum

import numpy as np
esavary marked this conversation as resolved.
Show resolved Hide resolved
esavary marked this conversation as resolved.
Show resolved Hide resolved


oesteban marked this conversation as resolved.
Show resolved Hide resolved
class SortingStrategy(Enum):
    """Closed set of orderings in which DWI volume indices can be visited.

    Members
    -------
    LINEAR
        Keep the indices in their original, ascending order.
    RANDOM
        Shuffle the indices pseudo-randomly.
    BVALUE
        Order the indices by ascending b-value.
    CENTRALSYM
        Start at the central index and alternate outwards on either side.
    """

    LINEAR = "linear"
    RANDOM = "random"
    BVALUE = "bvalue"
    CENTRALSYM = "centralsym"


def sort_dwdata_indices(dwdata, strategy, seed=None):
    """Sort the DWI data volume indices following the given strategy.

    Parameters
    ----------
    dwdata : :obj:`~eddymotion.dmri.DWI`
        DWI dataset, represented by this tool's internal type.
    strategy : :obj:`~eddymotion.utils.SortingStrategy` or :obj:`str`
        The sorting strategy to be used, given either as an enum member or
        as its string value (``"linear"``, ``"random"``, ``"bvalue"``,
        ``"centralsym"``). Available options are:

        - ``SortingStrategy.LINEAR``: Sort the indices linearly.
        - ``SortingStrategy.RANDOM``: Sort the indices randomly.
        - ``SortingStrategy.BVALUE``: Sort the indices based on the last
          column of gradients in ascending order.
        - ``SortingStrategy.CENTRALSYM``: Sort the indices in a central
          symmetric manner.
    seed : :obj:`int` or :obj:`bool`, optional
        Seed the random number generator. If an integer, the value is used to
        initialize the generator; if ``True``, the arbitrary value
        of ``20210324`` is used to initialize it. Only the ``RANDOM``
        strategy reads this argument.

    Returns
    -------
    index_order : :obj:`numpy.ndarray`
        The sorted index order.

    Raises
    ------
    ValueError
        If ``strategy`` is neither a ``SortingStrategy`` member nor one of
        its string values.
    """
    try:
        # Enum lookup by value: members map to themselves, strings to their
        # member, anything else raises ValueError.
        resolved = SortingStrategy(strategy)
    except ValueError:
        valid = ", ".join(repr(member.value) for member in SortingStrategy)
        raise ValueError(
            f"Invalid sorting strategy {strategy!r}; expected one of: {valid}"
        ) from None

    # RANDOM is the only strategy that takes the extra ``seed`` argument.
    if resolved is SortingStrategy.RANDOM:
        return random_action(dwdata, seed)

    actions = {
        SortingStrategy.LINEAR: linear_action,
        SortingStrategy.BVALUE: bvalue_action,
        SortingStrategy.CENTRALSYM: centralsym_action,
    }
    return actions[resolved](dwdata)

esavary marked this conversation as resolved.
Show resolved Hide resolved

def linear_action(dwdata):
    """Return the DWI volume indices in their natural, ascending order.

    Parameters
    ----------
    dwdata : :obj:`~eddymotion.dmri.DWI`
        DWI dataset, represented by this tool's internal type. Only its
        length is used here.

    Returns
    -------
    index_order : :obj:`numpy.ndarray`
        The indices ``0, 1, ..., len(dwdata) - 1``.
    """
    return np.arange(len(dwdata))


def random_action(dwdata, seed=None):
    """Return the DWI volume indices in a shuffled (pseudo-random) order.

    Parameters
    ----------
    dwdata : :obj:`~eddymotion.dmri.DWI`
        DWI dataset, represented by this tool's internal type. Only its
        length is used here.
    seed : :obj:`int` or :obj:`bool`, optional
        Seed for the random number generator. An integer is used as-is to
        initialize the generator; ``True`` selects the arbitrary value
        ``20210324``; ``None`` (the default) leaves the generator unseeded.

    Returns
    -------
    index_order : :obj:`numpy.ndarray`
        A permutation of ``0 .. len(dwdata) - 1``.
    """
    if seed is True:
        # ``True`` is shorthand for the project's fixed default seed.
        rng_seed = 20210324
    elif seed or seed == 0:
        # Any truthy value, or an explicit zero, is a user-provided seed.
        rng_seed = seed
    else:
        rng_seed = None

    shuffled = np.arange(len(dwdata))
    np.random.default_rng(rng_seed).shuffle(shuffled)
    return shuffled
esavary marked this conversation as resolved.
Show resolved Hide resolved


def bvalue_action(dwdata):
    """Sort the DWI volume indices by ascending value of the last gradient column.

    A stable sort is used, so entries with equal values keep their original
    relative order.

    Parameters
    ----------
    dwdata : :obj:`~eddymotion.dmri.DWI`
        DWI dataset, represented by this tool's internal type. Its
        ``gradients`` attribute must be a 2D array.

    Returns
    -------
    index_order : :obj:`numpy.ndarray`
        Indices that sort ``dwdata.gradients[:, -1]`` in ascending order.
    """
    # NOTE(review): this assumes one row per volume with the b-value in the
    # last column, i.e. shape (N, 4). If ``gradients`` is stored as (4, N),
    # ``[:, -1]`` selects the last *volume* instead of the b-values -- confirm
    # against the DWI class before relying on this ordering.
    last_column = dwdata.gradients[:, -1]
    # NumPy's default sort kind is not stable; request a stable one so that
    # ties (common within a shell) are ordered reproducibly.
    return np.argsort(last_column, kind="stable")


def centralsym_action(dwdata):
    """Sort the DWI volume indices in a central symmetric manner.

    The sequence starts at the central index and then alternates between the
    next index above and the next index below it. For an even number of
    volumes the centre is taken as ``len(dwdata) // 2 - 1``, so the highest
    index comes last.

    Parameters
    ----------
    dwdata : :obj:`~eddymotion.dmri.DWI`
        DWI dataset, represented by this tool's internal type. Only its
        length is used here.

    Returns
    -------
    index_order : :obj:`numpy.ndarray`
        The sorted index order; empty if ``dwdata`` is empty.

    Examples
    --------
    >>> centralsym_action(range(5)).tolist()
    [2, 3, 1, 4, 0]
    >>> centralsym_action(range(6)).tolist()
    [2, 3, 1, 4, 0, 5]
    """
    n_volumes = len(dwdata)
    index_order = np.arange(n_volumes)

    # Central index: n // 2 for odd counts, n // 2 - 1 for even counts.
    middle_point = (n_volumes - 1) // 2

    # Even output slots walk from the centre down to 0; odd slots walk from
    # just above the centre up to the last index. Both assignments are no-ops
    # on an empty dataset, so no special case is needed.
    index_order[0::2] = np.arange(middle_point, -1, -1)
    index_order[1::2] = np.arange(middle_point + 1, n_volumes)

    return index_order
esavary marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion test/test_model.py
oesteban marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
import pytest

from eddymotion import model
from eddymotion.data.splitting import lovo_split
from eddymotion.data.dmri import DWI
from eddymotion.data.splitting import lovo_split


def test_trivial_model():
Expand Down
6 changes: 3 additions & 3 deletions test/test_splitting.py
oesteban marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
# https://www.nipreps.org/community/licensing/
#
"""Unit test testing the lovo_split function."""

import numpy as np

from eddymotion.data.dmri import DWI
from eddymotion.data.splitting import lovo_split

Expand Down Expand Up @@ -50,13 +52,11 @@ def test_lovo_split(datadir):
data.gradients[..., index] = 1

# Apply the lovo_split function at the specified index
(train_data, train_gradients), \
(test_data, test_gradients) = lovo_split(data, index)
(train_data, train_gradients), (test_data, test_gradients) = lovo_split(data, index)

# Check if the test data contains only 1s
# and the train data contains only 0s after the split
assert np.all(test_data == 1)
assert np.all(train_data == 0)
assert np.all(test_gradients == 1)
assert np.all(train_gradients == 0)

Loading