Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
WIP: Use scikit-learn k-folds
Browse files Browse the repository at this point in the history
Use `scikit-learn` k-folds.
  • Loading branch information
jhlegarreta committed Oct 6, 2024
1 parent 3b326f7 commit acf683b
Showing 1 changed file with 53 additions and 24 deletions.
77 changes: 53 additions & 24 deletions scripts/dwi_estimation_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from dipy.sims.voxel import all_tensor_evecs, single_tensor
from matplotlib import pyplot as plt
from sklearn.metrics import root_mean_squared_error
from sklearn.model_selection import KFold

from eddymotion.model._dipy import GaussianProcessModel

Expand Down Expand Up @@ -150,30 +151,47 @@ def perform_experiment(gtab, S0, evals1, evecs, snr, repeats, kfold):
a = 1.0
sigma_sq = 0.5

data = []
data = {}

nzero_bvecs = gtab.bvecs[~gtab.b0s_mask]

# Simulate the fitting a number of times: every time the signal created will be a little
# different
# for _ in range(repeats):
# Create the DWI signal using a single tensor
signal = single_tensor(gtab, S0=S0, evals=evals1, evecs=evecs, snr=snr, rng=rng)

# Loop over the number of indices that are left out from the training/need to be predicted
for n in kfold:
# Assumptions:
# - Consecutive indices in the folds
# - A single b0
kf = KFold(n_splits=n, shuffle=False)

# Define the Gaussian process model instance
gp_model = GaussianProcessModel(
kernel_model=kernel_model, lambda_s=lambda_s, a=a, sigma_sq=sigma_sq
)

# Create the training mask leaving out the requested number of samples
train_mask = create_random_train_mask(gtab, n)
_data = []

for _, (train_index, test_index) in enumerate(kf.split(nzero_bvecs)):
# Create the training mask leaving out the requested number of samples
# train_mask = create_random_train_mask(gtab, n)

# Fit the Gaussian process
# Add 1 to account for the b0
gpfit = gp_model.fit(signal[train_index + 1], gtab[train_index + 1])

# Simulate the fitting a number of times: every time the signal created will be a little
# different
# for _ in range(repeats):
# Create the DWI signal using a single tensor
signal = single_tensor(gtab, S0=S0, evals=evals1, evecs=evecs, snr=snr, rng=rng)
# Fit the Gaussian process
gpfit = gp_model.fit(signal[train_mask], gtab[train_mask])
# Predict the signal
# X_qry, idx_qry = get_query_vectors(gtab, train_mask)
# Add 1 to account for the b0
idx_qry = test_index + 1
X_qry = gtab[idx_qry].bvecs
_y_pred, _y_std = gpfit.predict(X_qry)
_data.append((idx_qry, signal[idx_qry], _y_pred, _y_std))

# Predict the signal
X_qry, idx_qry = get_query_vectors(gtab, train_mask)
_y_pred, _y_std = gpfit.predict(X_qry)
data.append((idx_qry, signal[idx_qry], _y_pred, _y_std))
data.update({n: _data})

return data

Expand All @@ -185,18 +203,26 @@ def compute_error(data, repeats, kfolds):
std_dev = []

# Loop over the range of indices that were predicted
for n in range(len(kfolds)):
repeats = 1
_data = np.array(data[n * repeats : n * repeats + repeats])
_rmse = root_mean_squared_error(_data[0][1], _data[0][2])
_std_dev = np.mean(_data[0][3]) # np.std(_rmse)
for vals in data.values():
# repeats = 1
# _data = np.array(vals[n * repeats : n * repeats + repeats])
_signal = np.hstack([t[1] for t in vals])
_pred = np.hstack([t[2] for t in vals])
_rmse = root_mean_squared_error(_signal, _pred)
# ToDo
# Check here what is the value that we wil keep for the std
_std = np.hstack([t[3] for t in vals])
_std_dev = np.mean(_std)
_std_dev = np.std(
[root_mean_squared_error([v1], [v2]) for v1, v2 in zip(_signal, _pred, strict=False)]
) # np.std(_rmse)
mean_rmse.append(_rmse)
std_dev.append(_std_dev)

return np.asarray(mean_rmse), np.asarray(std_dev)


def plot_error(kfolds, mean, std_dev):
def plot_error(kfolds, mean, std_dev, xlabel, ylabel, title):
"""Plot the error and standard deviation."""

fig, ax = plt.subplots()
Expand All @@ -209,11 +235,11 @@ def plot_error(kfolds, mean, std_dev):
color="orange",
)
ax.scatter(kfolds, mean, c="orange")
ax.set_xlabel("N")
ax.set_ylabel("RMSE")
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_xticks(kfolds)
ax.set_xticklabels(kfolds)
ax.set_title("Gaussian process estimation")
ax.set_title(title)
fig.tight_layout()

return fig
Expand Down Expand Up @@ -294,7 +320,10 @@ def main():
rmse, std_dev = compute_error(data, args.repeats, args.kfold)

# Plot
_ = plot_error(args.kfold, rmse, std_dev)
xlabel = "N"
ylabel = "RMSE"
title = f"Gaussian process estimation\n(SNR={args.snr})"
_ = plot_error(args.kfold, rmse, std_dev, xlabel, ylabel, title)
# fig = plot_error(args.kfold, rmse, std_dev)
# fig.save(args.gp_pred_plot_error_fname, format="svg")

Expand Down

0 comments on commit acf683b

Please sign in to comment.