From a3a3f1b719bdc3c1986edf1f6f18ade57393f9ea Mon Sep 17 00:00:00 2001 From: Teresa Gomez <46339554+teresamg@users.noreply.github.com> Date: Thu, 15 Dec 2022 14:28:39 -0800 Subject: [PATCH] Fixed logo_split() call and dwdata->data --- src/eddymotion/data/splitting.py | 2 +- src/eddymotion/estimator.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/eddymotion/data/splitting.py b/src/eddymotion/data/splitting.py index d6229575..aec47059 100644 --- a/src/eddymotion/data/splitting.py +++ b/src/eddymotion/data/splitting.py @@ -47,7 +47,7 @@ def lovo_split(data, index): """ # if the size of the mask does not match data, cache is stale - mask = np.zeros(len(dwdata), dtype=bool) + mask = np.zeros(len(data), dtype=bool) mask[index] = True train_data = data.dataobj[..., ~mask] diff --git a/src/eddymotion/estimator.py b/src/eddymotion/estimator.py index ad49454e..b864c266 100644 --- a/src/eddymotion/estimator.py +++ b/src/eddymotion/estimator.py @@ -151,11 +151,7 @@ def fit( pbar.set_description_str( f"Pass {i_iter + 1}/{n_iter} | Fit and predict b-index <{i}>" ) - dwframe = np.asanyarray(dwdata.dataobj[..., i]) - bframe = np.asanyarray(dwdata.gradients[..., i]) - data_train, data_test = logo_split( - dwdata, dwframe, bframe, i, with_b0=True - ) + data_train, data_test = logo_split(dwdata, i) grad_str = f"{i}, {data_test[1][:3]}, b={int(data_test[1][3])}" pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs")