Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Fixed logo_split() call and dwdata->data
Browse files Browse the repository at this point in the history
  • Loading branch information
teresamg authored and oesteban committed Mar 27, 2024
1 parent 88bf1dc commit a3a3f1b
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/eddymotion/data/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def lovo_split(data, index):
"""

# if the size of the mask does not match data, cache is stale
mask = np.zeros(len(dwdata), dtype=bool)
mask = np.zeros(len(data), dtype=bool)
mask[index] = True

train_data = data.dataobj[..., ~mask]
Expand Down
6 changes: 1 addition & 5 deletions src/eddymotion/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,7 @@ def fit(
pbar.set_description_str(
f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>"
)
dwframe = np.asanyarray(dwdata.dataobj[..., i])
bframe = np.asanyarray(dwdata.gradients[..., i])
data_train, data_test = logo_split(
dwdata, dwframe, bframe, i, with_b0=True
)
data_train, data_test = logo_split(dwdata, i)

grad_str = f"{i}, {data_test[1][:3]}, b={int(data_test[1][3])}"
pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs")
Expand Down

0 comments on commit a3a3f1b

Please sign in to comment.