Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Merge pull request #154 from esavary/style-correction
Browse files Browse the repository at this point in the history
STY: fix style errors
  • Loading branch information
oesteban authored Apr 8, 2024
2 parents a1bcacd + 0c36275 commit ecd7ce8
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/eddymotion/data/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
# https://www.nipreps.org/community/licensing/
#
"""Data splitting helpers."""

from pathlib import Path
import numpy as np

import h5py
import numpy as np


def lovo_split(dataset, index, with_b0=False):
Expand Down
2 changes: 1 addition & 1 deletion test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
import pytest

from eddymotion import model
from eddymotion.data.splitting import lovo_split
from eddymotion.data.dmri import DWI
from eddymotion.data.splitting import lovo_split


def test_trivial_model():
Expand Down
6 changes: 3 additions & 3 deletions test/test_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
# https://www.nipreps.org/community/licensing/
#
"""Unit test testing the lovo_split function."""

import numpy as np

from eddymotion.data.dmri import DWI
from eddymotion.data.splitting import lovo_split

Expand Down Expand Up @@ -50,13 +52,11 @@ def test_lovo_split(datadir):
data.gradients[..., index] = 1

# Apply the lovo_split function at the specified index
(train_data, train_gradients), \
(test_data, test_gradients) = lovo_split(data, index)
(train_data, train_gradients), (test_data, test_gradients) = lovo_split(data, index)

# Check if the test data contains only 1s
# and the train data contains only 0s after the split
assert np.all(test_data == 1)
assert np.all(train_data == 0)
assert np.all(test_gradients == 1)
assert np.all(train_gradients == 0)

0 comments on commit ecd7ce8

Please sign in to comment.