Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
ENH: Remove cross-validation and RMSE computation/plot from GP notebook
Browse files Browse the repository at this point in the history
Remove the cross-validation and RMSE computation and plot cells from the
GP estimation notebook.
  • Loading branch information
jhlegarreta committed Nov 9, 2024
1 parent 6dc7e89 commit bca7e8a
Showing 1 changed file with 0 additions and 191 deletions.
191 changes: 0 additions & 191 deletions docs/notebooks/dwi_gp_estimation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -452,197 +452,6 @@
"plot_dwi(shell_data[..., dwi_vol_idx], affine, gradient=np.concatenate((np.squeeze(X_test), [1000])))\n",
"plot_dwi(dwi_sim2, affine, gradient=np.concatenate((np.squeeze(X_test), [1000])), output_file=\"sherbrooke_3shell_b1k_gp_opt_pred.svg\")"
]
},
{
"cell_type": "markdown",
"id": "5c77d954",
"metadata": {},
"source": [
"I'm not sure the cross-validation effort and the rest is worth trying right now."
]
},
{
"cell_type": "markdown",
"id": "4697b2182f5178d7",
"metadata": {},
"source": [
"## Cross-validation\n",
"\n",
"Use a k-fold cross-validation and a grid search to find the best parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e92ebf1204c7f889",
"metadata": {},
"outputs": [],
"source": [
"# This is doubling-down the parameter optimization\n",
"# from sklearn.model_selection import GridSearchCV, KFold\n",
"\n",
"# Define the hyperparameter grid\n",
"# param_grid = {\n",
"# \"kernel__beta_a\": np.linspace(np.pi/4, np.pi/2, 2),\n",
"# \"kernel__beta_l\": np.linspace(0.1, 1 / 2.1, 2),\n",
"# \"alpha\": np.linspace(1e-3, 1e-2, 2)\n",
"# }\n",
"\n",
"# Define k-fold cross-validation\n",
"# cv = KFold(n_splits=5, shuffle=True, random_state=seed)\n",
"\n",
"# # Perform grid search with cross-validation\n",
"# grid_search = GridSearchCV(estimator=gpr, param_grid=param_grid, cv=cv, scoring=\"neg_mean_squared_error\")\n",
"# grid_search.fit(X_train, sampled_dwi)\n",
"\n",
"# print(f\"Best parameters found: {grid_search.best_params_}\")\n",
"# print(f\"Best cross-validation score: {-grid_search.best_score_}\")"
]
},
{
"cell_type": "markdown",
"id": "9249e76858c166e7",
"metadata": {},
"source": [
"Train the GP leaving out a randomly picked diffusion-encoding gradient direction and predict on it using the optimized parameters"
]
},
{
"cell_type": "markdown",
"id": "83166ebc22702475",
"metadata": {},
"source": [
"Define the GP instances with the optimized hiperparameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7e1ca2ccc899189",
"metadata": {},
"outputs": [],
"source": [
"kernel = SphericalKriging(beta_a=grid_search.best_params_[\"kernel__beta_a\"], beta_l=grid_search.best_params_[\"kernel__beta_l\"])\n",
"gpr = EddyMotionGPR(kernel=kernel, alpha=grid_search.best_params_[\"alpha\"], disp=disp, optimizer=optimizer)"
]
},
{
"cell_type": "markdown",
"id": "1ebe09c2e2fef822",
"metadata": {},
"source": [
"Pick a random diffusion-encoding gradient direction and call fit/predict"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba7e22252116b374",
"metadata": {},
"outputs": [],
"source": [
"idx = rng.integers(0, len(indices))\n",
"idx_mask = np.zeros(len(indices), dtype=bool)\n",
"idx_mask[idx] = True\n",
"\n",
"X_train = bvecs_shell[~idx_mask]\n",
"_dwi_mask = np.repeat(brain_mask[..., np.newaxis], X_train.shape[0], axis=-1)\n",
"y_train = shell_data[..., ~idx_mask][_dwi_mask].reshape((X_train.shape[0], -1))\n",
"\n",
"gpr_fit = gpr.fit(X_train, y_train)\n",
"\n",
"X_test = bvecs_shell[idx_mask]\n",
"y_pred = gpr_fit.predict(X_test)"
]
},
{
"cell_type": "markdown",
"id": "dc93122f8370d2cf",
"metadata": {},
"source": [
"Plot the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5c8e303b4639b8c",
"metadata": {},
"outputs": [],
"source": [
"x_slice = shell_data[..., idx][slice_idx[0], :, :]\n",
"y_slice = shell_data[..., idx][:, slice_idx[1], :]\n",
"z_slice = shell_data[..., idx][:, :, slice_idx[2]]\n",
"slices = [x_slice, y_slice, z_slice]\n",
"\n",
"fig, axes = plt.subplots(1, len(slices))\n",
"for i, _slice in enumerate(slices):\n",
" axes[i].imshow(_slice.T, cmap=\"gray\", origin=\"lower\", aspect=\"equal\")\n",
" axes[i].set_axis_off()\n",
"\n",
"plt.suptitle(\"Data\")\n",
"plt.show()\n",
"\n",
"# Reshape the predicted data array to the image shape\n",
"brain_mask_idx = np.where(brain_mask)\n",
"_y_pred = np.zeros((shell_data.shape[:-1]), dtype=y_train.dtype)\n",
"_y_pred[brain_mask_idx] = y_pred.squeeze()\n",
"\n",
"x_slice = _y_pred[slice_idx[0], :, :]\n",
"y_slice = _y_pred[:, slice_idx[1], :]\n",
"z_slice = _y_pred[:, :, slice_idx[2]]\n",
"slices = [x_slice, y_slice, z_slice]\n",
"\n",
"fig, axes = plt.subplots(1, len(slices))\n",
"for i, _slice in enumerate(slices):\n",
" axes[i].imshow(_slice.T, cmap=\"gray\", origin=\"lower\", aspect=\"equal\")\n",
" axes[i].set_axis_off()\n",
"\n",
"plt.suptitle(\"GP prediction\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "8e9c1fd2f4d2bfd8",
"metadata": {},
"source": [
"Compute the RMSE and plot it"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd54cb47f4e2fbf3",
"metadata": {},
"outputs": [],
"source": [
"# Multiply the data by the brain mask to remove spurious values that were not predicted\n",
"rmse = np.sqrt(np.mean(np.square(shell_data[..., idx]*brain_mask - _y_pred)))\n",
"_rmse_element = np.sqrt(np.square(shell_data[..., idx]*brain_mask - _y_pred))\n",
"\n",
"print(f\"RMSE: {rmse}\")\n",
"threshold = 10\n",
"n_error_thr = len(_rmse_element[_rmse_element > threshold])\n",
"ratio = n_error_thr / np.prod(_rmse_element.shape) * 100\n",
"print(f\"Number of RMSE values above {threshold}: {n_error_thr} ({ratio}%)\")\n",
"\n",
"# Plot the RSME\n",
"x_slice = _rmse_element[slice_idx[0], :, :]\n",
"y_slice = _rmse_element[:, slice_idx[1], :]\n",
"z_slice = _rmse_element[:, :, slice_idx[2]]\n",
"slices = [x_slice, y_slice, z_slice]\n",
"\n",
"fig, axes = plt.subplots(1, len(slices))\n",
"images = []\n",
"for i, _slice in enumerate(slices):\n",
" images.append(axes[i].imshow(_slice.T, cmap=\"viridis\", origin=\"lower\", aspect=\"equal\"))\n",
" axes[i].set_axis_off()\n",
"\n",
"plt.colorbar(images[-1])\n",
"plt.suptitle(\"RMSE\")\n",
"plt.show()"
]
}
],
"metadata": {
Expand Down

0 comments on commit bca7e8a

Please sign in to comment.