diff --git a/docs/notebooks/dwi_gp_estimation.ipynb b/docs/notebooks/dwi_gp_estimation.ipynb index b4e01e00..2cce4b3e 100644 --- a/docs/notebooks/dwi_gp_estimation.ipynb +++ b/docs/notebooks/dwi_gp_estimation.ipynb @@ -452,197 +452,6 @@ "plot_dwi(shell_data[..., dwi_vol_idx], affine, gradient=np.concatenate((np.squeeze(X_test), [1000])))\n", "plot_dwi(dwi_sim2, affine, gradient=np.concatenate((np.squeeze(X_test), [1000])), output_file=\"sherbrooke_3shell_b1k_gp_opt_pred.svg\")" ] - }, - { - "cell_type": "markdown", - "id": "5c77d954", - "metadata": {}, - "source": [ - "I'm not sure the cross-validation effort and the rest is worth trying right now." - ] - }, - { - "cell_type": "markdown", - "id": "4697b2182f5178d7", - "metadata": {}, - "source": [ - "## Cross-validation\n", - "\n", - "Use a k-fold cross-validation and a grid search to find the best parameters." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e92ebf1204c7f889", - "metadata": {}, - "outputs": [], - "source": [ - "# This is doubling-down the parameter optimization\n", - "# from sklearn.model_selection import GridSearchCV, KFold\n", - "\n", - "# Define the hyperparameter grid\n", - "# param_grid = {\n", - "# \"kernel__beta_a\": np.linspace(np.pi/4, np.pi/2, 2),\n", - "# \"kernel__beta_l\": np.linspace(0.1, 1 / 2.1, 2),\n", - "# \"alpha\": np.linspace(1e-3, 1e-2, 2)\n", - "# }\n", - "\n", - "# Define k-fold cross-validation\n", - "# cv = KFold(n_splits=5, shuffle=True, random_state=seed)\n", - "\n", - "# # Perform grid search with cross-validation\n", - "# grid_search = GridSearchCV(estimator=gpr, param_grid=param_grid, cv=cv, scoring=\"neg_mean_squared_error\")\n", - "# grid_search.fit(X_train, sampled_dwi)\n", - "\n", - "# print(f\"Best parameters found: {grid_search.best_params_}\")\n", - "# print(f\"Best cross-validation score: {-grid_search.best_score_}\")" - ] - }, - { - "cell_type": "markdown", - "id": "9249e76858c166e7", - "metadata": {}, - "source": [ - "Train the GP leaving out a randomly picked diffusion-encoding gradient direction and predict on it using the optimized parameters" - ] - }, - { - "cell_type": "markdown", - "id": "83166ebc22702475", - "metadata": {}, - "source": [ - "Define the GP instances with the optimized hiperparameters" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b7e1ca2ccc899189", - "metadata": {}, - "outputs": [], - "source": [ - "kernel = SphericalKriging(beta_a=grid_search.best_params_[\"kernel__beta_a\"], beta_l=grid_search.best_params_[\"kernel__beta_l\"])\n", - "gpr = EddyMotionGPR(kernel=kernel, alpha=grid_search.best_params_[\"alpha\"], disp=disp, optimizer=optimizer)" - ] - }, - { - "cell_type": "markdown", - "id": "1ebe09c2e2fef822", - "metadata": {}, - "source": [ - "Pick a random diffusion-encoding gradient direction and call fit/predict" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ba7e22252116b374", - "metadata": {}, - "outputs": [], - "source": [ - "idx = rng.integers(0, len(indices))\n", - "idx_mask = np.zeros(len(indices), dtype=bool)\n", - "idx_mask[idx] = True\n", - "\n", - "X_train = bvecs_shell[~idx_mask]\n", - "_dwi_mask = np.repeat(brain_mask[..., np.newaxis], X_train.shape[0], axis=-1)\n", - "y_train = shell_data[..., ~idx_mask][_dwi_mask].reshape((X_train.shape[0], -1))\n", - "\n", - "gpr_fit = gpr.fit(X_train, y_train)\n", - "\n", - "X_test = bvecs_shell[idx_mask]\n", - "y_pred = gpr_fit.predict(X_test)" - ] - }, - { - "cell_type": "markdown", - "id": "dc93122f8370d2cf", - "metadata": {}, - "source": [ - "Plot the data." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e5c8e303b4639b8c", - "metadata": {}, - "outputs": [], - "source": [ - "x_slice = shell_data[..., idx][slice_idx[0], :, :]\n", - "y_slice = shell_data[..., idx][:, slice_idx[1], :]\n", - "z_slice = shell_data[..., idx][:, :, slice_idx[2]]\n", - "slices = [x_slice, y_slice, z_slice]\n", - "\n", - "fig, axes = plt.subplots(1, len(slices))\n", - "for i, _slice in enumerate(slices):\n", - " axes[i].imshow(_slice.T, cmap=\"gray\", origin=\"lower\", aspect=\"equal\")\n", - " axes[i].set_axis_off()\n", - "\n", - "plt.suptitle(\"Data\")\n", - "plt.show()\n", - "\n", - "# Reshape the predicted data array to the image shape\n", - "brain_mask_idx = np.where(brain_mask)\n", - "_y_pred = np.zeros((shell_data.shape[:-1]), dtype=y_train.dtype)\n", - "_y_pred[brain_mask_idx] = y_pred.squeeze()\n", - "\n", - "x_slice = _y_pred[slice_idx[0], :, :]\n", - "y_slice = _y_pred[:, slice_idx[1], :]\n", - "z_slice = _y_pred[:, :, slice_idx[2]]\n", - "slices = [x_slice, y_slice, z_slice]\n", - "\n", - "fig, axes = plt.subplots(1, len(slices))\n", - "for i, _slice in enumerate(slices):\n", - " axes[i].imshow(_slice.T, cmap=\"gray\", origin=\"lower\", aspect=\"equal\")\n", - " axes[i].set_axis_off()\n", - "\n", - "plt.suptitle(\"GP prediction\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "8e9c1fd2f4d2bfd8", - "metadata": {}, - "source": [ - "Compute the RMSE and plot it" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fd54cb47f4e2fbf3", - "metadata": {}, - "outputs": [], - "source": [ - "# Multiply the data by the brain mask to remove spurious values that were not predicted\n", - "rmse = np.sqrt(np.mean(np.square(shell_data[..., idx]*brain_mask - _y_pred)))\n", - "_rmse_element = np.sqrt(np.square(shell_data[..., idx]*brain_mask - _y_pred))\n", - "\n", - "print(f\"RMSE: {rmse}\")\n", - "threshold = 10\n", - "n_error_thr = len(_rmse_element[_rmse_element > threshold])\n", - "ratio = n_error_thr / np.prod(_rmse_element.shape) * 100\n", - "print(f\"Number of RMSE values above {threshold}: {n_error_thr} ({ratio}%)\")\n", - "\n", - "# Plot the RSME\n", - "x_slice = _rmse_element[slice_idx[0], :, :]\n", - "y_slice = _rmse_element[:, slice_idx[1], :]\n", - "z_slice = _rmse_element[:, :, slice_idx[2]]\n", - "slices = [x_slice, y_slice, z_slice]\n", - "\n", - "fig, axes = plt.subplots(1, len(slices))\n", - "images = []\n", - "for i, _slice in enumerate(slices):\n", - " images.append(axes[i].imshow(_slice.T, cmap=\"viridis\", origin=\"lower\", aspect=\"equal\"))\n", - " axes[i].set_axis_off()\n", - "\n", - "plt.colorbar(images[-1])\n", - "plt.suptitle(\"RMSE\")\n", - "plt.show()" - ] } ], "metadata": {