Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Oscar Esteban <[email protected]>
  • Loading branch information
esavary and oesteban authored Mar 28, 2024
1 parent 1d6762c commit bcff16d
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/eddymotion/data/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
import h5py


def lovo_split(data, index, with_b0=False):
def lovo_split(dataset, index, with_b0=False):
"""
Produce one fold of LOVO (leave-one-volume-out).
Parameters
----------
data : :obj:`eddymotion.data.dmri.DWI`
dataset : :obj:`eddymotion.data.dmri.DWI`
DWI object
index : :obj:`int`
Index of the DWI orientation to be left out in this fold.
Expand All @@ -48,25 +48,25 @@ def lovo_split(data, index, with_b0=False):
"""

if not Path(data.get_filename()).exists():
data.to_filename(data.get_filename())
if not Path(dataset.get_filename()).exists():
dataset.to_filename(data.get_filename())

# read original DWI data & b-vector
with h5py.File(data.get_filename(), "r") as in_file:
with h5py.File(dataset.get_filename(), "r") as in_file:
root = in_file["/0"]
dwframe = np.asanyarray(root["dataobj"][..., index])
bframe = np.asanyarray(root["gradients"][..., index])
data = np.asanyarray(root["dataobj"])
gradients = np.asanyarray(root["gradients"])

# if the size of the mask does not match data, cache is stale
mask = np.zeros(len(data), dtype=bool)
mask = np.zeros(data.shape[-1], dtype=bool)
mask[index] = True

train_data = data.dataobj[..., ~mask]
train_gradients = data.gradients[..., ~mask]
train_data = data[..., ~mask]
train_gradients = gradients[..., ~mask]

if with_b0:
train_data = np.concatenate(
(np.asanyarray(data.bzero)[..., np.newaxis], train_data),
(np.asanyarray(dataset.bzero)[..., np.newaxis], train_data),
axis=-1,
)
b0vec = np.zeros((4, 1))
Expand Down

0 comments on commit bcff16d

Please sign in to comment.