Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
wip adding gp model
Browse files Browse the repository at this point in the history
  • Loading branch information
josephmje committed Dec 2, 2021
1 parent 5f17862 commit 46ee4f5
Show file tree
Hide file tree
Showing 2 changed files with 450 additions and 9 deletions.
385 changes: 385 additions & 0 deletions docs/notebooks/Testing GP model.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,385 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "b395dedd-ae15-4788-9911-b9050e3ff784",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from pathlib import Path\n",
"import shutil\n",
"import warnings\n",
"\n",
"from dipy.core.gradients import gradient_table\n",
"\n",
"from eddymotion import dmri\n",
"from eddymotion.viz import plot_dwi\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "750f9a01-aac3-452e-8fec-0fcc4034d448",
"metadata": {},
"outputs": [],
"source": [
"base_dir = Path(\"/Users/michael/projects/datasets/ds000206\")\n",
"bids_dir = base_dir / \"bids\"\n",
"derivatives_dir = base_dir / \"dmriprep\"\n",
"\n",
"dwi_file = bids_dir / \"sub-05\" / \"ses-JHU1\" / \"dwi\" / \"sub-05_ses-JHU1_acq-GD72_dwi.nii.gz\"\n",
"bvec_file = bids_dir / \"sub-05\" / \"ses-JHU1\" / \"dwi\" / \"sub-05_ses-JHU1_acq-GD72_dwi.bvec\"\n",
"bval_file = bids_dir / \"sub-05\" / \"ses-JHU1\" / \"dwi\" / \"sub-05_ses-JHU1_acq-GD72_dwi.bval\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b2e3b427-ff19-4e5a-ace9-dd1f561f26c3",
"metadata": {},
"outputs": [],
"source": [
"gtab = gradient_table(str(bval_file), str(bvec_file))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e8a744ef-9822-4df4-936f-15e4cfd2b871",
"metadata": {},
"outputs": [],
"source": [
"#from dmriprep.interfaces.vectors import CheckGradientTable\n",
"\n",
"#gen_rasb = CheckGradientTable(dwi_file=str(dwi_file),\n",
"# in_bvec=str(bvec_file),\n",
"# in_bval=str(bval_file)\n",
"# ).run()\n",
"\n",
"rasb_file = bids_dir / \"sub-05\" / \"ses-JHU1\" / \"dwi\" / \"sub-05_ses-JHU1_acq-GD72_dwi.tsv\"\n",
"#shutil.copy(\"sub-05_ses-JHU1_acq-GD72_dwi.tsv\", str(rasb_file))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0184689a-0951-4bf4-a359-fe2338654809",
"metadata": {},
"outputs": [],
"source": [
"b0_file = derivatives_dir / \"sub-05\" / \"ses-JHU1\" / \"dwi\" / \"sub-05_ses-JHU1_acq-GD72_desc-b0_dwi.nii.gz\"\n",
"brainmask_file = derivatives_dir / \"sub-05\" / \"ses-JHU1\" / \"dwi\" / \"sub-05_ses-JHU1_acq-GD72_desc-brain_mask.nii.gz\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5a848fc7-a0b4-4015-af55-a99ec1cbf7b0",
"metadata": {},
"outputs": [],
"source": [
"data = dmri.load(\n",
" str(dwi_file),\n",
" gradients_file=str(rasb_file),\n",
" b0_file=str(b0_file),\n",
" brainmask_file=str(brainmask_file)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a417e10d-a1f9-41f4-83a0-0731448ead76",
"metadata": {},
"outputs": [],
"source": [
"def _rasb2dipy(gradient):\n",
" gradient = np.asanyarray(gradient)\n",
" if gradient.ndim == 1:\n",
" if gradient.size != 4:\n",
" raise ValueError(\"Missing gradient information.\")\n",
" gradient = gradient[..., np.newaxis]\n",
"\n",
" if gradient.shape[0] != 4:\n",
" gradient = gradient.T\n",
" elif gradient.shape == (4, 4):\n",
" print(\"Warning: make sure gradient information is not transposed!\")\n",
"\n",
" with warnings.catch_warnings():\n",
" warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
" retval = gradient_table(gradient[3, :], gradient[:3, :].T)\n",
" return retval"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "760f45fa-d38f-4121-b9a3-9be1a5aeb6d0",
"metadata": {},
"outputs": [],
"source": [
"class DKIModel:\n",
" \"\"\"A wrapper of :obj:`dipy.reconst.dki.DiffusionKurtosisModel.\"\"\"\n",
"\n",
" __slots__ = (\"_model\", \"_S0\", \"_mask\")\n",
"\n",
" def __init__(self, gtab, S0=None, mask=None, **kwargs):\n",
" \"\"\"Instantiate the wrapped tensor model.\"\"\"\n",
" from dipy.reconst.dki import DiffusionKurtosisModel\n",
"\n",
" self._S0 = None\n",
" if S0 is not None:\n",
" self._S0 = np.clip(\n",
" S0.astype(\"float32\") / S0.max(),\n",
" a_min=1e-5,\n",
" a_max=1.0,\n",
" )\n",
" self._mask = mask\n",
" if mask is None and S0 is not None:\n",
" self._mask = self._S0 > np.percentile(self._S0, 35)\n",
"\n",
" if self._mask is not None:\n",
" self._S0 = self._S0[self._mask.astype(bool)]\n",
"\n",
" kwargs = {\n",
" k: v\n",
" for k, v in kwargs.items()\n",
" if k\n",
" in (\n",
" \"min_signal\",\n",
" \"return_S0_hat\",\n",
" \"fit_method\",\n",
" \"weighting\",\n",
" \"sigma\",\n",
" \"jac\",\n",
" )\n",
" }\n",
" self._model = DiffusionKurtosisModel(gtab, **kwargs)\n",
"\n",
" def fit(self, data, **kwargs):\n",
" \"\"\"Clean-up permitted args and kwargs, and call model's fit.\"\"\"\n",
" self._model = self._model.fit(data[self._mask, ...])\n",
"\n",
" def predict(self, gradient, **kwargs):\n",
" \"\"\"Propagate model parameters and call predict.\"\"\"\n",
" predicted = np.squeeze(\n",
" self._model.predict(\n",
" _rasb2dipy(gradient),\n",
" S0=self._S0,\n",
" )\n",
" )\n",
" if predicted.ndim == 3:\n",
" return predicted\n",
"\n",
" retval = np.zeros_like(self._mask, dtype=\"float32\")\n",
" retval[self._mask, ...] = predicted\n",
" return retval"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "63cce457-29ef-4439-9677-962ab3e019a2",
"metadata": {},
"outputs": [],
"source": [
"class DTIModel:\n",
" \"\"\"A wrapper of :obj:`dipy.reconst.dti.TensorModel.\"\"\"\n",
"\n",
" __slots__ = (\"_model\", \"_S0\", \"_mask\")\n",
"\n",
" def __init__(self, gtab, S0=None, mask=None, **kwargs):\n",
" \"\"\"Instantiate the wrapped tensor model.\"\"\"\n",
" from dipy.reconst.dti import TensorModel as DipyTensorModel\n",
"\n",
" self._S0 = None\n",
" if S0 is not None:\n",
" self._S0 = np.clip(\n",
" S0.astype(\"float32\") / S0.max(),\n",
" a_min=1e-5,\n",
" a_max=1.0,\n",
" )\n",
"\n",
" self._mask = mask\n",
" if mask is None and S0 is not None:\n",
" self._mask = self._S0 > np.percentile(self._S0, 35)\n",
"\n",
" if self._mask is not None:\n",
" self._S0 = self._S0[self._mask.astype(bool)]\n",
"\n",
" kwargs = {\n",
" k: v\n",
" for k, v in kwargs.items()\n",
" if k\n",
" in (\n",
" \"min_signal\",\n",
" \"return_S0_hat\",\n",
" \"fit_method\",\n",
" \"weighting\",\n",
" \"sigma\",\n",
" \"jac\",\n",
" )\n",
" }\n",
" self._model = DipyTensorModel(_rasb2dipy(gtab), **kwargs)\n",
"\n",
" def fit(self, data, **kwargs):\n",
" \"\"\"Fit the model chunk-by-chunk asynchronously.\"\"\"\n",
" self._model = self._model.fit(data[self._mask, ...])\n",
"\n",
" def predict(self, gradient, **kwargs):\n",
" \"\"\"Propagate model parameters and call predict.\"\"\"\n",
" predicted = np.squeeze(\n",
" self._model.predict(\n",
" _rasb2dipy(gradient),\n",
" S0=self._S0,\n",
" )\n",
" )\n",
" if predicted.ndim == 3:\n",
" return predicted\n",
"\n",
" retval = np.zeros_like(self._mask, dtype=\"float32\")\n",
" retval[self._mask, ...] = predicted\n",
" return retval"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a0538db7-79ee-46f4-91fd-e8def6eedafe",
"metadata": {},
"outputs": [],
"source": [
"class SparseFascicleModel:\n",
" \"\"\"\n",
" A wrapper of :obj:`dipy.reconst.sfm.SparseFascicleModel.\n",
" \"\"\"\n",
"\n",
" __slots__ = (\"_model\", \"_S0\", \"_mask\", \"_solver\")\n",
"\n",
" def __init__(self, gtab, S0=None, mask=None, solver=None, **kwargs):\n",
" \"\"\"Instantiate the wrapped model.\"\"\"\n",
" from dipy.reconst.sfm import SparseFascicleModel\n",
" from sklearn.gaussian_process import GaussianProcessRegressor\n",
"\n",
" self._S0 = None\n",
" if S0 is not None:\n",
" self._S0 = np.clip(\n",
" S0.astype(\"float32\") / S0.max(),\n",
" a_min=1e-5,\n",
" a_max=1.0,\n",
" )\n",
"\n",
" self._mask = mask\n",
" if mask is None and S0 is not None:\n",
" self._mask = self._S0 > np.percentile(self._S0, 35)\n",
"\n",
" if self._mask is not None:\n",
" self._S0 = self._S0[self._mask.astype(bool)]\n",
"\n",
" self._solver = solver\n",
" if solver is None:\n",
" self._solver = \"ElasticNet\"\n",
"\n",
" kwargs = {k: v for k, v in kwargs.items() if k in (\"solver\",)}\n",
" self._model = SparseFascicleModel(gtab, **kwargs)\n",
"\n",
" def fit(self, data, **kwargs):\n",
" \"\"\"Clean-up permitted args and kwargs, and call model's fit.\"\"\"\n",
" self._model = self._model.fit(data[self._mask, ...])\n",
"\n",
" def predict(self, gradient, **kwargs):\n",
" \"\"\"Propagate model parameters and call predict.\"\"\"\n",
" predicted = np.squeeze(\n",
" self._model.predict(\n",
" _rasb2dipy(gradient),\n",
" S0=self._S0,\n",
" )\n",
" )\n",
" if predicted.ndim == 3:\n",
" return predicted\n",
"\n",
" retval = np.zeros_like(self._mask, dtype=\"float32\")\n",
" retval[self._mask, ...] = predicted\n",
" return retval"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "98e05ec4-00cf-42ad-b566-9b52d3afb0cd",
"metadata": {},
"outputs": [],
"source": [
"model = DTIModel(\n",
" gtab=data.gradients,\n",
" S0=data.bzero,\n",
" mask=data.brainmask\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "46fd2487-c8f5-4c05-833d-b5f790be0819",
"metadata": {},
"outputs": [],
"source": [
"data_train, data_test = data.logo_split(10)\n",
"model.fit(data_train[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5aede348-5726-4ba5-8d28-9cfdbc9caf4e",
"metadata": {},
"outputs": [],
"source": [
"predicted = model.predict(data_test[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75a77093-18f3-4388-a0a8-be3841c71d3e",
"metadata": {},
"outputs": [],
"source": [
"plot_dwi(predicted, data.affine, gradient=data_test[1]);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74b5228c-8220-420b-8149-26f2b58a8850",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit 46ee4f5

Please sign in to comment.