Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Fix: return test data and gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
esavary committed Mar 28, 2024
1 parent bcff16d commit 316eb57
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/eddymotion/data/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def lovo_split(dataset, index, with_b0=False):
mask[index] = True

train_data = data[..., ~mask]
train_gradients = gradients[..., ~mask]
train_gradients = gradients[..., mask]
test_data = data[..., ~mask]
test_gradients = gradients[..., mask]

if with_b0:
train_data = np.concatenate(
Expand All @@ -78,5 +80,5 @@ def lovo_split(dataset, index, with_b0=False):

return (
(train_data, train_gradients),
(dwframe, bframe),
(test_data, test_gradients),
)

0 comments on commit 316eb57

Please sign in to comment.