From acf683baa8b5b34a4fee1194bc0de7d4949edb11 Mon Sep 17 00:00:00 2001
From: Jon Haitz Legarreta Gorroño
Date: Sun, 6 Oct 2024 18:09:39 -0400
Subject: [PATCH] WIP: Use `scikit-learn` k-folds

Use `scikit-learn`'s `KFold` cross-validator to split the non-b0
diffusion-encoding directions into training/test folds instead of drawing a
random training mask. Store the per-fold predictions in a dictionary keyed by
the number of splits, adapt `compute_error` to the new layout, and make the
plot axis labels and title configurable.

---
 scripts/dwi_estimation_error_analysis.py | 77 ++++++++++++++++--------
 1 file changed, 53 insertions(+), 24 deletions(-)

diff --git a/scripts/dwi_estimation_error_analysis.py b/scripts/dwi_estimation_error_analysis.py
index 8b74dfea..7f1aa717 100644
--- a/scripts/dwi_estimation_error_analysis.py
+++ b/scripts/dwi_estimation_error_analysis.py
@@ -35,6 +35,7 @@
 from dipy.sims.voxel import all_tensor_evecs, single_tensor
 from matplotlib import pyplot as plt
 from sklearn.metrics import root_mean_squared_error
+from sklearn.model_selection import KFold
 
 from eddymotion.model._dipy import GaussianProcessModel
 
@@ -150,30 +151,47 @@ def perform_experiment(gtab, S0, evals1, evecs, snr, repeats, kfold):
     a = 1.0
     sigma_sq = 0.5
 
-    data = []
+    data = {}
+
+    nzero_bvecs = gtab.bvecs[~gtab.b0s_mask]
+
+    # Simulate the fitting a number of times: every time the signal created will be a little
+    # different
+    # for _ in range(repeats):
+    # Create the DWI signal using a single tensor
+    signal = single_tensor(gtab, S0=S0, evals=evals1, evecs=evecs, snr=snr, rng=rng)
 
     # Loop over the number of indices that are left out from the training/need to be predicted
     for n in kfold:
+        # Assumptions:
+        # - Consecutive indices in the folds
+        # - A single b0
+        kf = KFold(n_splits=n, shuffle=False)
+
         # Define the Gaussian process model instance
         gp_model = GaussianProcessModel(
             kernel_model=kernel_model, lambda_s=lambda_s, a=a, sigma_sq=sigma_sq
         )
 
-        # Create the training mask leaving out the requested number of samples
-        train_mask = create_random_train_mask(gtab, n)
+        _data = []
+
+        for train_index, test_index in kf.split(nzero_bvecs):
+            # Create the training mask leaving out the requested number of samples
+            # train_mask = create_random_train_mask(gtab, n)
+
+            # Fit the Gaussian process
+            # Add 1 to account for the b0
+            gpfit = gp_model.fit(signal[train_index + 1], gtab[train_index + 1])
 
-        # Simulate the fitting a number of times: every time the signal created will be a little
-        # different
-        # for _ in range(repeats):
-        # Create the DWI signal using a single tensor
-        signal = single_tensor(gtab, S0=S0, evals=evals1, evecs=evecs, snr=snr, rng=rng)
-        # Fit the Gaussian process
-        gpfit = gp_model.fit(signal[train_mask], gtab[train_mask])
+            # Predict the signal
+            # X_qry, idx_qry = get_query_vectors(gtab, train_mask)
+            # Add 1 to account for the b0
+            idx_qry = test_index + 1
+            X_qry = gtab[idx_qry].bvecs
+            _y_pred, _y_std = gpfit.predict(X_qry)
+            _data.append((idx_qry, signal[idx_qry], _y_pred, _y_std))
 
-        # Predict the signal
-        X_qry, idx_qry = get_query_vectors(gtab, train_mask)
-        _y_pred, _y_std = gpfit.predict(X_qry)
-        data.append((idx_qry, signal[idx_qry], _y_pred, _y_std))
+        data.update({n: _data})
 
     return data
 
@@ -185,18 +203,26 @@ def compute_error(data, repeats, kfolds):
     std_dev = []
 
     # Loop over the range of indices that were predicted
-    for n in range(len(kfolds)):
-        repeats = 1
-        _data = np.array(data[n * repeats : n * repeats + repeats])
-        _rmse = root_mean_squared_error(_data[0][1], _data[0][2])
-        _std_dev = np.mean(_data[0][3])  # np.std(_rmse)
+    for vals in data.values():
+        # repeats = 1
+        # _data = np.array(vals[n * repeats : n * repeats + repeats])
+        _signal = np.hstack([t[1] for t in vals])
+        _pred = np.hstack([t[2] for t in vals])
+        _rmse = root_mean_squared_error(_signal, _pred)
+        # TODO
+        # Check which value we will keep for the std
+        _std = np.hstack([t[3] for t in vals])
+        _std_dev = np.mean(_std)
+        _std_dev = np.std(
+            [root_mean_squared_error([v1], [v2]) for v1, v2 in zip(_signal, _pred, strict=False)]
+        )  # np.std(_rmse)
         mean_rmse.append(_rmse)
         std_dev.append(_std_dev)
 
     return np.asarray(mean_rmse), np.asarray(std_dev)
 
 
-def plot_error(kfolds, mean, std_dev):
+def plot_error(kfolds, mean, std_dev, xlabel, ylabel, title):
     """Plot the error and standard deviation."""
 
     fig, ax = plt.subplots()
@@ -209,11 +235,11 @@ def plot_error(kfolds, mean, std_dev):
         color="orange",
     )
     ax.scatter(kfolds, mean, c="orange")
-    ax.set_xlabel("N")
-    ax.set_ylabel("RMSE")
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
     ax.set_xticks(kfolds)
     ax.set_xticklabels(kfolds)
-    ax.set_title("Gaussian process estimation")
+    ax.set_title(title)
     fig.tight_layout()
 
     return fig
@@ -294,7 +320,10 @@ def main():
     rmse, std_dev = compute_error(data, args.repeats, args.kfold)
 
     # Plot
-    _ = plot_error(args.kfold, rmse, std_dev)
+    xlabel = "N"
+    ylabel = "RMSE"
+    title = f"Gaussian process estimation\n(SNR={args.snr})"
+    _ = plot_error(args.kfold, rmse, std_dev, xlabel, ylabel, title)
     # fig = plot_error(args.kfold, rmse, std_dev)
     # fig.save(args.gp_pred_plot_error_fname, format="svg")
 
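Reviewer note, not part of the commit: a minimal, self-contained sketch of the
indexing assumption encoded in the `perform_experiment` hunk above. `KFold` is
run over the non-b0 directions only, and the fold indices are shifted by one
because the gradient table is assumed to start with a single b0 volume. The
array sizes and the `n_dirs` name are illustrative, not taken from the script.

    import numpy as np
    from sklearn.model_selection import KFold

    rng = np.random.default_rng(1234)

    # Hypothetical gradient table layout: one b0 followed by `n_dirs` DWI directions.
    n_dirs = 60
    bvecs = np.vstack([np.zeros(3), rng.normal(size=(n_dirs, 3))])
    b0s_mask = np.zeros(len(bvecs), dtype=bool)
    b0s_mask[0] = True

    # Split only the non-b0 directions, as `perform_experiment` now does.
    nzero_bvecs = bvecs[~b0s_mask]
    kf = KFold(n_splits=5, shuffle=False)
    for train_index, test_index in kf.split(nzero_bvecs):
        # Shift by 1 so the indices address the full table, where the b0 sits at position 0.
        train_full = train_index + 1
        test_full = test_index + 1
        assert np.intersect1d(train_full, test_full).size == 0
        assert 0 not in train_full and 0 not in test_full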
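A second, equally hypothetical sketch of how `compute_error` reduces the new
`data` layout returned by `perform_experiment` (a dict mapping the number of
splits to a list of per-fold `(idx_qry, signal, y_pred, y_std)` tuples) to a
single RMSE per fold count. The arrays below are random placeholders.

    import numpy as np
    from sklearn.metrics import root_mean_squared_error

    rng = np.random.default_rng(1234)

    # Two folds of a 5-fold split over 60 directions: 12 held-out samples each.
    folds = [
        (np.arange(1, 13), rng.random(12), rng.random(12), rng.random(12)),
        (np.arange(13, 25), rng.random(12), rng.random(12), rng.random(12)),
    ]

    # Concatenate the held-out signals and predictions across folds ...
    signal = np.hstack([t[1] for t in folds])
    pred = np.hstack([t[2] for t in folds])

    # ... and score them jointly, which is what `compute_error` does for each fold count.
    rmse = root_mean_squared_error(signal, pred)
    print(f"RMSE over {signal.size} held-out samples: {rmse:.3f}")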