Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
WIP: Use scikit-learn GP returned std error value
Browse files Browse the repository at this point in the history
Use scikit-learn GP returned std error value.
  • Loading branch information
jhlegarreta committed Sep 29, 2024
1 parent 62f13ba commit ea93c4a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
25 changes: 13 additions & 12 deletions scripts/dwi_estimation_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,16 +164,16 @@ def perform_experiment(gtab, S0, evals1, evecs, snr, repeats, kfold):

# Simulate the fitting a number of times: every time the signal created will be a little
# different
for _ in range(repeats):
# Create the DWI signal using a single tensor
signal = single_tensor(gtab, S0=S0, evals=evals1, evecs=evecs, snr=snr, rng=rng)
# Fit the Gaussian process
gpfit = gp_model.fit(signal[train_mask], gtab[train_mask])
# for _ in range(repeats):
# Create the DWI signal using a single tensor
signal = single_tensor(gtab, S0=S0, evals=evals1, evecs=evecs, snr=snr, rng=rng)
# Fit the Gaussian process
gpfit = gp_model.fit(signal[train_mask], gtab[train_mask])

# Predict the signal
X_qry, idx_qry = get_query_vectors(gtab, train_mask)
_y_pred = gpfit.predict(X_qry)
data.append((idx_qry, signal[idx_qry], _y_pred))
# Predict the signal
X_qry, idx_qry = get_query_vectors(gtab, train_mask)
_y_pred, _y_std = gpfit.predict(X_qry)
data.append((idx_qry, signal[idx_qry], _y_pred, _y_std))

return data

Expand All @@ -186,10 +186,11 @@ def compute_error(data, repeats, kfolds):

# Loop over the range of indices that were predicted
for n in range(len(kfolds)):
repeats = 1
_data = np.array(data[n * repeats : n * repeats + repeats])
_rmse = list(map(root_mean_squared_error, _data[:, 1], _data[:, 2]))
_std_dev = np.std(_rmse)
mean_rmse.append(np.mean(_rmse))
_rmse = root_mean_squared_error(_data[0][1], _data[0][2])
_std_dev = np.mean(_data[0][3]) # np.std(_rmse)
mean_rmse.append(_rmse)
std_dev.append(_std_dev)

return np.asarray(mean_rmse), np.asarray(std_dev)
Expand Down
2 changes: 1 addition & 1 deletion src/eddymotion/model/_dipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def gp_prediction(
raise RuntimeError("Model is not yet fitted.")

# Extract orientations from gtab, and highly likely, the b-value too.
return model.predict(gtab, return_std=False)
return model.predict(gtab, return_std=True)

Check warning on line 141 in src/eddymotion/model/_dipy.py

View check run for this annotation

Codecov / codecov/patch

src/eddymotion/model/_dipy.py#L141

Added line #L141 was not covered by tests


class GaussianProcessModel(ReconstModel):
Expand Down

0 comments on commit ea93c4a

Please sign in to comment.