From 46ee4f5fb54d5c99a2f47a75a3e50f1ee4dcc7ad Mon Sep 17 00:00:00 2001 From: Michael Joseph Date: Thu, 2 Dec 2021 09:47:31 -0500 Subject: [PATCH] wip adding gp model --- docs/notebooks/Testing GP model.ipynb | 385 ++++++++++++++++++++++++++ eddymotion/model.py | 74 ++++- 2 files changed, 450 insertions(+), 9 deletions(-) create mode 100644 docs/notebooks/Testing GP model.ipynb diff --git a/docs/notebooks/Testing GP model.ipynb b/docs/notebooks/Testing GP model.ipynb new file mode 100644 index 00000000..6568122c --- /dev/null +++ b/docs/notebooks/Testing GP model.ipynb @@ -0,0 +1,385 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "b395dedd-ae15-4788-9911-b9050e3ff784", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from pathlib import Path\n", + "import shutil\n", + "import warnings\n", + "\n", + "from dipy.core.gradients import gradient_table\n", + "\n", + "from eddymotion import dmri\n", + "from eddymotion.viz import plot_dwi\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "750f9a01-aac3-452e-8fec-0fcc4034d448", + "metadata": {}, + "outputs": [], + "source": [ + "base_dir = Path(\"/Users/michael/projects/datasets/ds000206\")\n", + "bids_dir = base_dir / \"bids\"\n", + "derivatives_dir = base_dir / \"dmriprep\"\n", + "\n", + "dwi_file = bids_dir / \"sub-05\" / \"ses-JHU1\" / \"dwi\" / \"sub-05_ses-JHU1_acq-GD72_dwi.nii.gz\"\n", + "bvec_file = bids_dir / \"sub-05\" / \"ses-JHU1\" / \"dwi\" / \"sub-05_ses-JHU1_acq-GD72_dwi.bvec\"\n", + "bval_file = bids_dir / \"sub-05\" / \"ses-JHU1\" / \"dwi\" / \"sub-05_ses-JHU1_acq-GD72_dwi.bval\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b2e3b427-ff19-4e5a-ace9-dd1f561f26c3", + "metadata": {}, + "outputs": [], + "source": [ + "gtab = gradient_table(str(bval_file), str(bvec_file))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e8a744ef-9822-4df4-936f-15e4cfd2b871", + "metadata": {}, + "outputs": [], + "source": [ + "#from dmriprep.interfaces.vectors import CheckGradientTable\n", + "\n", + "#gen_rasb = CheckGradientTable(dwi_file=str(dwi_file),\n", + "# in_bvec=str(bvec_file),\n", + "# in_bval=str(bval_file)\n", + "# ).run()\n", + "\n", + "rasb_file = bids_dir / \"sub-05\" / \"ses-JHU1\" / \"dwi\" / \"sub-05_ses-JHU1_acq-GD72_dwi.tsv\"\n", + "#shutil.copy(\"sub-05_ses-JHU1_acq-GD72_dwi.tsv\", str(rasb_file))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0184689a-0951-4bf4-a359-fe2338654809", + "metadata": {}, + "outputs": [], + "source": [ + "b0_file = derivatives_dir / \"sub-05\" / \"ses-JHU1\" / \"dwi\" / \"sub-05_ses-JHU1_acq-GD72_desc-b0_dwi.nii.gz\"\n", + "brainmask_file = derivatives_dir / \"sub-05\" / \"ses-JHU1\" / \"dwi\" / \"sub-05_ses-JHU1_acq-GD72_desc-brain_mask.nii.gz\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "5a848fc7-a0b4-4015-af55-a99ec1cbf7b0", + "metadata": {}, + "outputs": [], + "source": [ + "data = dmri.load(\n", + " str(dwi_file),\n", + " gradients_file=str(rasb_file),\n", + " b0_file=str(b0_file),\n", + " brainmask_file=str(brainmask_file)\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a417e10d-a1f9-41f4-83a0-0731448ead76", + "metadata": {}, + "outputs": [], + "source": [ + "def _rasb2dipy(gradient):\n", + " gradient = np.asanyarray(gradient)\n", + " if gradient.ndim == 1:\n", + " if gradient.size != 4:\n", + " raise ValueError(\"Missing gradient information.\")\n", + " gradient = gradient[..., np.newaxis]\n", + "\n", + " if gradient.shape[0] != 4:\n", + " gradient = gradient.T\n", + " elif gradient.shape == (4, 4):\n", + " print(\"Warning: make sure gradient information is not transposed!\")\n", + "\n", + " with warnings.catch_warnings():\n", + " warnings.filterwarnings(\"ignore\", category=UserWarning)\n", + " retval = gradient_table(gradient[3, :], gradient[:3, :].T)\n", + " return retval" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "760f45fa-d38f-4121-b9a3-9be1a5aeb6d0", + "metadata": {}, + "outputs": [], + "source": [ + "class DKIModel:\n", + " \"\"\"A wrapper of :obj:`dipy.reconst.dki.DiffusionKurtosisModel.\"\"\"\n", + "\n", + " __slots__ = (\"_model\", \"_S0\", \"_mask\")\n", + "\n", + " def __init__(self, gtab, S0=None, mask=None, **kwargs):\n", + " \"\"\"Instantiate the wrapped tensor model.\"\"\"\n", + " from dipy.reconst.dki import DiffusionKurtosisModel\n", + "\n", + " self._S0 = None\n", + " if S0 is not None:\n", + " self._S0 = np.clip(\n", + " S0.astype(\"float32\") / S0.max(),\n", + " a_min=1e-5,\n", + " a_max=1.0,\n", + " )\n", + " self._mask = mask\n", + " if mask is None and S0 is not None:\n", + " self._mask = self._S0 > np.percentile(self._S0, 35)\n", + "\n", + " if self._mask is not None:\n", + " self._S0 = self._S0[self._mask.astype(bool)]\n", + "\n", + " kwargs = {\n", + " k: v\n", + " for k, v in kwargs.items()\n", + " if k\n", + " in (\n", + " \"min_signal\",\n", + " \"return_S0_hat\",\n", + " \"fit_method\",\n", + " \"weighting\",\n", + " \"sigma\",\n", + " \"jac\",\n", + " )\n", + " }\n", + " self._model = DiffusionKurtosisModel(gtab, **kwargs)\n", + "\n", + " def fit(self, data, **kwargs):\n", + " \"\"\"Clean-up permitted args and kwargs, and call model's fit.\"\"\"\n", + " self._model = self._model.fit(data[self._mask, ...])\n", + "\n", + " def predict(self, gradient, **kwargs):\n", + " \"\"\"Propagate model parameters and call predict.\"\"\"\n", + " predicted = np.squeeze(\n", + " self._model.predict(\n", + " _rasb2dipy(gradient),\n", + " S0=self._S0,\n", + " )\n", + " )\n", + " if predicted.ndim == 3:\n", + " return predicted\n", + "\n", + " retval = np.zeros_like(self._mask, dtype=\"float32\")\n", + " retval[self._mask, ...] = predicted\n", + " return retval" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "63cce457-29ef-4439-9677-962ab3e019a2", + "metadata": {}, + "outputs": [], + "source": [ + "class DTIModel:\n", + " \"\"\"A wrapper of :obj:`dipy.reconst.dti.TensorModel.\"\"\"\n", + "\n", + " __slots__ = (\"_model\", \"_S0\", \"_mask\")\n", + "\n", + " def __init__(self, gtab, S0=None, mask=None, **kwargs):\n", + " \"\"\"Instantiate the wrapped tensor model.\"\"\"\n", + " from dipy.reconst.dti import TensorModel as DipyTensorModel\n", + "\n", + " self._S0 = None\n", + " if S0 is not None:\n", + " self._S0 = np.clip(\n", + " S0.astype(\"float32\") / S0.max(),\n", + " a_min=1e-5,\n", + " a_max=1.0,\n", + " )\n", + "\n", + " self._mask = mask\n", + " if mask is None and S0 is not None:\n", + " self._mask = self._S0 > np.percentile(self._S0, 35)\n", + "\n", + " if self._mask is not None:\n", + " self._S0 = self._S0[self._mask.astype(bool)]\n", + "\n", + " kwargs = {\n", + " k: v\n", + " for k, v in kwargs.items()\n", + " if k\n", + " in (\n", + " \"min_signal\",\n", + " \"return_S0_hat\",\n", + " \"fit_method\",\n", + " \"weighting\",\n", + " \"sigma\",\n", + " \"jac\",\n", + " )\n", + " }\n", + " self._model = DipyTensorModel(_rasb2dipy(gtab), **kwargs)\n", + "\n", + " def fit(self, data, **kwargs):\n", + " \"\"\"Fit the model chunk-by-chunk asynchronously.\"\"\"\n", + " self._model = self._model.fit(data[self._mask, ...])\n", + "\n", + " def predict(self, gradient, **kwargs):\n", + " \"\"\"Propagate model parameters and call predict.\"\"\"\n", + " predicted = np.squeeze(\n", + " self._model.predict(\n", + " _rasb2dipy(gradient),\n", + " S0=self._S0,\n", + " )\n", + " )\n", + " if predicted.ndim == 3:\n", + " return predicted\n", + "\n", + " retval = np.zeros_like(self._mask, dtype=\"float32\")\n", + " retval[self._mask, ...] = predicted\n", + " return retval" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a0538db7-79ee-46f4-91fd-e8def6eedafe", + "metadata": {}, + "outputs": [], + "source": [ + "class SparseFascicleModel:\n", + " \"\"\"\n", + " A wrapper of :obj:`dipy.reconst.sfm.SparseFascicleModel.\n", + " \"\"\"\n", + "\n", + " __slots__ = (\"_model\", \"_S0\", \"_mask\", \"_solver\")\n", + "\n", + " def __init__(self, gtab, S0=None, mask=None, solver=None, **kwargs):\n", + " \"\"\"Instantiate the wrapped model.\"\"\"\n", + " from dipy.reconst.sfm import SparseFascicleModel\n", + " from sklearn.gaussian_process import GaussianProcessRegressor\n", + "\n", + " self._S0 = None\n", + " if S0 is not None:\n", + " self._S0 = np.clip(\n", + " S0.astype(\"float32\") / S0.max(),\n", + " a_min=1e-5,\n", + " a_max=1.0,\n", + " )\n", + "\n", + " self._mask = mask\n", + " if mask is None and S0 is not None:\n", + " self._mask = self._S0 > np.percentile(self._S0, 35)\n", + "\n", + " if self._mask is not None:\n", + " self._S0 = self._S0[self._mask.astype(bool)]\n", + "\n", + " self._solver = solver\n", + " if solver is None:\n", + " self._solver = \"ElasticNet\"\n", + "\n", + " kwargs = {k: v for k, v in kwargs.items() if k in (\"solver\",)}\n", + " self._model = SparseFascicleModel(gtab, **kwargs)\n", + "\n", + " def fit(self, data, **kwargs):\n", + " \"\"\"Clean-up permitted args and kwargs, and call model's fit.\"\"\"\n", + " self._model = self._model.fit(data[self._mask, ...])\n", + "\n", + " def predict(self, gradient, **kwargs):\n", + " \"\"\"Propagate model parameters and call predict.\"\"\"\n", + " predicted = np.squeeze(\n", + " self._model.predict(\n", + " _rasb2dipy(gradient),\n", + " S0=self._S0,\n", + " )\n", + " )\n", + " if predicted.ndim == 3:\n", + " return predicted\n", + "\n", + " retval = np.zeros_like(self._mask, dtype=\"float32\")\n", + " retval[self._mask, ...] = predicted\n", + " return retval" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "98e05ec4-00cf-42ad-b566-9b52d3afb0cd", + "metadata": {}, + "outputs": [], + "source": [ + "model = DTIModel(\n", + " gtab=data.gradients,\n", + " S0=data.bzero,\n", + " mask=data.brainmask\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46fd2487-c8f5-4c05-833d-b5f790be0819", + "metadata": {}, + "outputs": [], + "source": [ + "data_train, data_test = data.logo_split(10)\n", + "model.fit(data_train[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5aede348-5726-4ba5-8d28-9cfdbc9caf4e", + "metadata": {}, + "outputs": [], + "source": [ + "predicted = model.predict(data_test[1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75a77093-18f3-4388-a0a8-be3841c71d3e", + "metadata": {}, + "outputs": [], + "source": [ + "plot_dwi(predicted, data.affine, gradient=data_test[1]);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74b5228c-8220-420b-8149-26f2b58a8850", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/eddymotion/model.py b/eddymotion/model.py index b3ebe36a..6010c94b 100644 --- a/eddymotion/model.py +++ b/eddymotion/model.py @@ -25,7 +25,7 @@ def init(gtab, model="DTI", **kwargs): An array representing the gradient table in RAS+B format. model : :obj:`str` Diffusion model. - Options: ``"3DShore"``, ``"SFM"``, ``"DTI"``, ``"DKI"``, ``"S0"`` + Options: ``"3DShore"``, ``"SFM"``, ``"GP"``, ``"DTI"``, ``"DKI"``, ``"S0"`` Return ------ @@ -53,15 +53,17 @@ def init(gtab, model="DTI", **kwargs): "lambdaL": 1e-8, } - elif model.lower().startswith("sfm"): - from eddymotion.utils.model import ( - SFM4HMC as Model, - ExponentialIsotropicModel, - ) + elif model.lower() in ("sfm", "gp"): + Model = SparseFascicleModel + if model.lower() == "sfm": + param = { + "solver": "ElasticNet", + "isotropic": ExponentialIsotropicModel, + } + else: + from sklearn.gaussian_process import GaussianProcessRegressor - param = { - "isotropic": ExponentialIsotropicModel, - } + param = {"solver": GaussianProcessRegressor} elif model.lower() in ("dti", "dki"): Model = DTIModel if model.lower() == "dti" else DKIModel @@ -352,3 +354,57 @@ def _rasb2dipy(gradient): def _model_fit(model, data): return model.fit(data) + + +class SparseFascicleModel: + """ + A wrapper of :obj:`dipy.reconst.sfm.SparseFascicleModel. + """ + + __slots__ = ("_model", "_S0", "_mask", "_solver") + + def __init__(self, gtab, S0=None, mask=None, solver=None, **kwargs): + """Instantiate the wrapped model.""" + from dipy.reconst.sfm import SparseFascicleModel + from sklearn.gaussian_process import GaussianProcessRegressor + + self._S0 = None + if S0 is not None: + self._S0 = np.clip( + S0.astype("float32") / S0.max(), + a_min=1e-5, + a_max=1.0, + ) + + self._mask = mask + if mask is None and S0 is not None: + self._mask = self._S0 > np.percentile(self._S0, 35) + + if self._mask is not None: + self._S0 = self._S0[self._mask.astype(bool)] + + self._solver = solver + if solver is None: + self._solver = "ElasticNet" + + kwargs = {k: v for k, v in kwargs.items() if k in ("solver",)} + self._model = SparseFascicleModel(gtab, **kwargs) + + def fit(self, data, **kwargs): + """Clean-up permitted args and kwargs, and call model's fit.""" + self._model = self._model.fit(data[self._mask, ...]) + + def predict(self, gradient, **kwargs): + """Propagate model parameters and call predict.""" + predicted = np.squeeze( + self._model.predict( + _rasb2dipy(gradient), + S0=self._S0, + ) + ) + if predicted.ndim == 3: + return predicted + + retval = np.zeros_like(self._mask, dtype="float32") + retval[self._mask, ...] = predicted + return retval