diff --git a/.gitignore b/.gitignore index 2697c7c..68091bc 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,6 @@ results/ temp/ # Figures -figures/Fig \ No newline at end of file +figures/Fig + +figures/temp_* \ No newline at end of file diff --git a/figures/02_gradient.ipynb b/figures/02_gradient.ipynb index 987fa52..175061c 100644 --- a/figures/02_gradient.ipynb +++ b/figures/02_gradient.ipynb @@ -9,12 +9,92 @@ "\"\"\"Example for plotting gradient data\"\"\"\n", "import os.path as op\n", "from glob import glob\n", + "import itertools\n", "\n", + "from sklearn.cluster import KMeans\n", + "from sklearn.metrics import silhouette_samples, silhouette_score\n", "import matplotlib.pyplot as plt\n", "\n", "from utils import plot_gradient, plot_subcortical_gradient" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import nibabel as nib\n", + "import pandas as pd\n", + "import seaborn as sns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import matplotlib.gridspec as gridspec\n", + "import seaborn as sns\n", + "import numpy as np\n", + "\n", + "class SeabornFig2Grid():\n", + "\n", + " def __init__(self, seaborngrid, fig, subplot_spec):\n", + " self.fig = fig\n", + " self.sg = seaborngrid\n", + " self.subplot = subplot_spec\n", + " if isinstance(self.sg, sns.axisgrid.FacetGrid) or \\\n", + " isinstance(self.sg, sns.axisgrid.PairGrid):\n", + " self._movegrid()\n", + " elif isinstance(self.sg, sns.axisgrid.JointGrid):\n", + " self._movejointgrid()\n", + " self._finalize()\n", + "\n", + " def _movegrid(self):\n", + " \"\"\" Move PairGrid or Facetgrid \"\"\"\n", + " self._resize()\n", + " n = self.sg.axes.shape[0]\n", + " m = self.sg.axes.shape[1]\n", + " self.subgrid = gridspec.GridSpecFromSubplotSpec(n,m, subplot_spec=self.subplot)\n", + " for i in range(n):\n", + " for j in range(m):\n", + " self._moveaxes(self.sg.axes[i,j], self.subgrid[i,j])\n", + "\n", + " def _movejointgrid(self):\n", + " \"\"\" Move Jointgrid \"\"\"\n", + " h= self.sg.ax_joint.get_position().height\n", + " h2= self.sg.ax_marg_x.get_position().height\n", + " r = int(np.round(h/h2))\n", + " self._resize()\n", + " self.subgrid = gridspec.GridSpecFromSubplotSpec(r+1,r+1, subplot_spec=self.subplot)\n", + "\n", + " self._moveaxes(self.sg.ax_joint, self.subgrid[1:, :-1])\n", + " self._moveaxes(self.sg.ax_marg_x, self.subgrid[0, :-1])\n", + " self._moveaxes(self.sg.ax_marg_y, self.subgrid[1:, -1])\n", + "\n", + " def _moveaxes(self, ax, gs):\n", + " #https://stackoverflow.com/a/46906599/4124317\n", + " ax.remove()\n", + " ax.figure=self.fig\n", + " self.fig.axes.append(ax)\n", + " self.fig.add_axes(ax)\n", + " ax._subplotspec = gs\n", + " ax.set_position(gs.get_position(self.fig))\n", + " ax.set_subplotspec(gs)\n", + "\n", + " def _finalize(self):\n", + " plt.close(self.sg.fig)\n", + " self.fig.canvas.mpl_connect(\"resize_event\", self._resize)\n", + " self.fig.canvas.draw()\n", + "\n", + " def _resize(self, evt=None):\n", + " self.sg.fig.set_size_inches(self.fig.get_size_inches())" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -69,26 +149,17 @@ "metadata": {}, "outputs": [], "source": [ - "jperaza_grad_path = op.abspath(\"../results/gradient\")\n", + "jperaza_grad_path = op.abspath(\"../results/gradient/GIFTI-NIfTI_files\")\n", "jperaza_grad_out_path = \"../figures/Fig/gradient\"\n", "jperaza_grad_lh_fnames = 
sorted(glob(op.join(jperaza_grad_path, \"*hemi-L_feature.func.gii\")))\n", "jperaza_grad_rh_fnames = sorted(glob(op.join(jperaza_grad_path, \"*hemi-R_feature.func.gii\")))\n", "jperaza_subcort_grad_fnames = sorted(glob(op.join(jperaza_grad_path, \"*_feature.nii.gz\")))\n", "jperaza_grad_fnames = zip(jperaza_grad_lh_fnames, jperaza_grad_rh_fnames)\n", "\n", - "plot_gradient(\"../data\", jperaza_grad_fnames, cbar=True, out_dir=jperaza_grad_out_path)\n", + "plot_gradient(\"../data\", jperaza_grad_fnames, cbar=False, layout='column', out_dir=jperaza_grad_out_path)\n", "# plot_subcortical_gradient(jperaza_subcort_grad_fnames, threshold_=0.01)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np" - ] - }, { "cell_type": "code", "execution_count": null, @@ -146,10 +217,282 @@ "ax2.set_xlabel(\"Component Difference ($C_{i} - C_{i+1}$)\")\n", "\n", "plt.tight_layout()\n", - "plt.savefig(op.join(\"./\", \"Fig\", \"Fig-S2.png\"), bbox_inches=\"tight\", dpi=500)\n", + "plt.savefig(op.join(\"./\", \"Fig\", \"Fig-S3.png\"), bbox_inches=\"tight\", dpi=1000)\n", "plt.show()\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sum(y1[:4])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gradients = np.load(\"../results/gradient/gradients.npy\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "template_dir = \"../data/templates\"\n", + "subcortical_fn = op.join(template_dir, \"rois-subcortical_mni152_mask.nii.gz\")\n", + "subcort_img = nib.load(subcortical_fn)\n", + "\n", + "full_vertices = 64984\n", + "hemi_vertices = full_vertices // 2\n", + "\n", + "subcort_dat = subcort_img.get_fdata()\n", + "subcort_mask = subcort_dat != 0\n", + "n_subcort_vox = np.where(subcort_mask)[0].shape[0]\n", + "\n", + "n_gradients = gradients.shape[1]\n", + "grad_lst = []\n", + "for i in range(n_gradients):\n", + " cort_grads = gradients[: gradients.shape[0] - n_subcort_vox, i]\n", + " grad_lst.append(cort_grads)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "labels_00 = list(np.load(\"./temp_segmentation/KMeans_00-04_labels.npy\"))\n", + "labels_01 = list(np.load(\"./temp_segmentation/KMeans_01-04_labels.npy\"))\n", + "labels_02 = list(np.load(\"./temp_segmentation/KMeans_02-04_labels.npy\"))\n", + "labels_03 = list(np.load(\"./temp_segmentation/KMeans_03-04_labels.npy\"))\n", + "labels_04 = list(np.load(\"./temp_segmentation/KMeans_04-04_labels.npy\"))\n", + "labels_05 = list(np.load(\"./temp_segmentation/KMeans_05-04_labels.npy\"))\n", + "labels_06 = list(np.load(\"./temp_segmentation/KMeans_06-04_labels.npy\"))\n", + "labels_07 = list(np.load(\"./temp_segmentation/KMeans_07-04_labels.npy\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "grad_df = pd.DataFrame(np.array(grad_lst).T, columns=[f\"Gradient {i+1}\" for i in range(n_gradients)])\n", + "grad_df[\"label_00\"] = labels_00\n", + "grad_df[\"label_01\"] = labels_01\n", + "grad_df[\"label_02\"] = labels_02\n", + "grad_df[\"label_03\"] = labels_03\n", + "grad_df[\"label_04\"] = labels_04\n", + "grad_df[\"label_05\"] = labels_05\n", + "grad_df[\"label_06\"] = labels_06\n", + "grad_df[\"label_07\"] = labels_07\n", + "grad_df" + ] + }, + { + "cell_type": "code", + 
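The `labels_00` through `labels_07` arrays loaded above come from `temp_segmentation/KMeans_{i:02d}-04_labels.npy` files that are not generated in this notebook. A minimal sketch of how such four-cluster labelings could be produced, assuming (as the file naming suggests) that solution `i` clusters the first `i + 1` gradient components; the output path and random seed are illustrative:

```python
import numpy as np
from sklearn.cluster import KMeans  # already imported at the top of this notebook

# grad_lst holds one cortical map per gradient component (built in the cell above).
grad_arr = np.array(grad_lst).T  # shape: (n_vertices, n_gradients)

for i in range(8):
    # Assumption: solution i uses the first i + 1 components, with 4 clusters ("-04").
    km = KMeans(n_clusters=4, n_init=10, random_state=0).fit(grad_arr[:, : i + 1])
    np.save(f"./temp_segmentation/KMeans_{i:02d}-04_labels.npy", km.labels_)
```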
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.set(style=\"whitegrid\", font_scale=2)\n", + "pal = [\"#0174BE\", \"#008170\", \"#D0D4CA\", \"#B31312\"] # \"#ED7D31\"\n", + "n_rows, n_columns = 4, 4\n", + "vars = [f\"Gradient {i+1}\" for i in range(9)]\n", + "\n", + "for row, column in itertools.product(range(n_rows), range(n_columns)):\n", + " if row < column:\n", + " x, y = vars[column], vars[row]\n", + " figure = plt.figure(figsize=(2, 2))\n", + " # figure, ax = plt.subplots(figsize=(3, 3))\n", + "\n", + " g = sns.jointplot(\n", + " data=grad_df, \n", + " x=x, \n", + " y=y, \n", + " hue=\"label\", \n", + " palette=pal, \n", + " alpha=1, \n", + " s=3,\n", + " legend=False,\n", + " )\n", + " g.ax_joint.collections[0].set_edgecolor('none')\n", + " g.ax_joint.collections[0].set_linewidth(0)\n", + " \n", + " if row == 0:\n", + " g.ax_marg_y.set_ylim(-5, 7)\n", + " if row == 1:\n", + " g.ax_marg_y.set_ylim(-3, 3)\n", + " if row == 2:\n", + " g.ax_marg_y.set_ylim(-2, 4)\n", + " if row == 3:\n", + " g.ax_marg_y.set_ylim(-2, 2)\n", + " if column == 1:\n", + " g.ax_marg_x.set_xlim(-3, 3)\n", + " if column == 2:\n", + " g.ax_marg_x.set_xlim(-2, 4)\n", + " if column == 3:\n", + " g.ax_marg_x.set_xlim(-2, 2)\n", + "\n", + "\n", + " plt.savefig(f\"./temp_fig/{row}-{column}_mutidimensional.png\", dpi=100)\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pal = [\"#0174BE\", \"#008170\", \"#D0D4CA\", \"#B31312\"]\n", + "for comp in [0, 1, 2, 3, 4, 5, 6, 7]:\n", + " figure = plt.figure(figsize=(2, 2))\n", + " g = sns.scatterplot(\n", + " data=grad_df, \n", + " x=\"Gradient 3\", \n", + " y=\"Gradient 1\", \n", + " hue=f\"label_{comp:02d}\", \n", + " palette=pal, \n", + " alpha=1, \n", + " s=3,\n", + " edgecolors=None,\n", + " linewidth=0,\n", + " legend=False,\n", + " )\n", + " g.set_axis_off()\n", + " plt.savefig(f\"./temp_fig/2D-gradient_G1:G{comp+1}.png\", bbox_inches=\"tight\", dpi=1000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import metrics\n", + "# comps = [0, 1, 2, 3, 4, 5, 6, 7]\n", + "rows = []\n", + "for i in range(8):\n", + " columns = []\n", + " for j in range(8):\n", + " labels_pred = grad_df[f\"label_{i:02d}\"].to_list()\n", + " labels_true = grad_df[f\"label_{j:02d}\"].to_list()\n", + " nmi = metrics.normalized_mutual_info_score(labels_true, labels_pred)\n", + " columns.append(nmi)\n", + " rows.append(columns)\n", + "corr = np.array(rows)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "label_list = [\"G1\", \"G1:G2\", \"G1:G3\", \"G1:G4\", \"G1:G5\", \"G1:G6\", \"G1:G7\", \"G1:G8\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mask = np.zeros_like(corr)\n", + "mask[np.triu_indices_from(mask)] = True\n", + "with sns.axes_style(\"white\"):\n", + " f, ax = plt.subplots(figsize=(7, 5))\n", + " ax = sns.heatmap(corr, mask=mask, vmin=0, vmax=1, annot=True, square=True)\n", + "\n", + " ax.set_xticklabels(label_list)\n", + " ax.set_yticklabels(label_list)\n", + " ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=14)\n", + " ax.set_xticklabels(ax.get_xticklabels(), rotation=90, fontsize=14)\n", + "plt.savefig(f\"./temp_fig/2D-gradient_NMI.png\", bbox_inches=\"tight\", dpi=1000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "from surfplot.utils import add_fslr_medial_wall\n", + "from matplotlib.colors import ListedColormap\n", + "from neuromaps.datasets import fetch_fslr\n", + "from surfplot import Plot\n", + "from brainspace.datasets import load_parcellation\n", + "\n", + "surfaces = fetch_fslr()\n", + "lh, rh = surfaces['inflated']\n", + "\n", + "full_vertices = 64984\n", + "hemi_vertices = full_vertices // 2\n", + "prin_grad = add_fslr_medial_wall(np.array(labels)) # Add medial wall for plotting\n", + "labels_lh, labels_rh = prin_grad[:hemi_vertices], prin_grad[hemi_vertices:full_vertices]\n", + "\n", + "pal = [\"#0174BE\", \"#008170\", \"#D0D4CA\", \"#B31312\"]\n", + "for region in range(4):\n", + " # zero-out all regions except 71 and 72\n", + " map_lh = np.zeros_like(labels_lh)\n", + " map_lh[labels_lh==region] = 1\n", + " map_rh = np.zeros_like(labels_rh)\n", + " map_rh[labels_rh==region] = 1\n", + "\n", + " cmap = ListedColormap(pal[region], 'regions', N=1)\n", + " p = Plot(rh, views='medial')\n", + " p.add_layer(map_rh, cmap=cmap, cbar=False)\n", + " p.add_layer(map_rh, cmap='Greys', as_outline=True, cbar=False)\n", + "\n", + " fig = p.build()\n", + " plt.savefig(f\"./temp_fig/{region}-gradient.tiff\", bbox_inches=\"tight\", dpi=500)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "n_rows, n_columns = 4, 4\n", + "figure1 = plt.figure(figsize=(12, 12))\n", + "gs = GridSpec(\n", + " 3, \n", + " 3, \n", + " figure=figure1, \n", + ")\n", + "\n", + "for row in range(n_rows):\n", + " for column in reversed(range(n_columns)):\n", + " if row < column:\n", + " print(row, column, row*(n_columns-1) + column - 1)\n", + " # fig = plt.figure(figsize=(4, 4))\n", + " img1 = mpimg.imread(f\"./temp_fig/{row}-{column}_mutidimensional.png\")\n", + "\n", + " ax = figure1.add_subplot(gs[row*(n_columns-1) + column - 1], aspect=\"equal\")\n", + " # gs.update(left=0.55, right=0.98, hspace=0.05)\n", + " ax.imshow(img1)\n", + " # if gradient_row == 3:\n", + " # ax1.set_title('Gradient {}'.format(i))\n", + " ax.set_axis_off()\n", + "\n", + "plt.subplots_adjust(wspace=-0.25, hspace=-0.18)\n", + "plt.savefig(f\"./temp_fig/mutidimensional.png\", bbox_inches=\"tight\", dpi=1000)\n", + "plt.show()" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/figures/03_segmentation.ipynb b/figures/03_segmentation.ipynb index 859c0b7..6b16d4f 100644 --- a/figures/03_segmentation.ipynb +++ b/figures/03_segmentation.ipynb @@ -11,9 +11,16 @@ "from glob import glob\n", "import pickle\n", "\n", + "import numpy as np\n", "import matplotlib.image as mpimg\n", "from matplotlib.gridspec import GridSpec\n", "import matplotlib.pyplot as plt\n", + "from nibabel import GiftiImage\n", + "from nibabel.gifti import GiftiDataArray\n", + "from surfplot.utils import add_fslr_medial_wall\n", + "from segmentation import KDESegmentation, KMeansSegmentation, PCTLSegmentation\n", + "import nibabel as nib\n", + "import numpy as np\n", "\n", "from utils import plot_gradient" ] @@ -28,6 +35,34 @@ "figures_dir = op.abspath(\"../figures\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gradients = np.load(\"../results/gradient/gradients.npy\")\n", + "\n", + "template_dir = \"../data/templates\"\n", + "subcortical_fn = op.join(template_dir, \"rois-subcortical_mni152_mask.nii.gz\")\n", + "subcort_img = nib.load(subcortical_fn)\n", + 
"\n", + "full_vertices = 64984\n", + "hemi_vertices = full_vertices // 2\n", + "\n", + "subcort_dat = subcort_img.get_fdata()\n", + "subcort_mask = subcort_dat != 0\n", + "n_subcort_vox = np.where(subcort_mask)[0].shape[0]\n", + "\n", + "n_gradients = gradients.shape[1]\n", + "grad_lst = []\n", + "for i in range(n_gradients):\n", + " cort_grads = gradients[: gradients.shape[0] - n_subcort_vox, i]\n", + " grad_lst.append(cort_grads)\n", + "\n", + "gradient = grad_lst[0]" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -50,13 +85,13 @@ "metadata": {}, "outputs": [], "source": [ - "percent_grad_seg_path = \"../results/segmentation/pct\"\n", + "percent_grad_seg_path = \"../results/segmentation/PCT_gradient-maps\"\n", "percent_grad_out_path = \"../figures/Fig/segmentation/pct\"\n", - "percent_grad_seg_lh_fnames = sorted(glob(op.join(percent_grad_seg_path, \"*Percentile*_desc-Bin*-L_feature.func.gii\")))\n", - "percent_grad_seg_rh_fnames = sorted(glob(op.join(percent_grad_seg_path, \"*Percentile*_desc-Bin*-R_feature.func.gii\")))\n", + "percent_grad_seg_lh_fnames = sorted(glob(op.join(percent_grad_seg_path, \"*PCT*_desc-C*-L_feature.func.gii\")))\n", + "percent_grad_seg_rh_fnames = sorted(glob(op.join(percent_grad_seg_path, \"*PCT*_desc-C*-R_feature.func.gii\")))\n", "percent_grad_seg_fnames = zip(percent_grad_seg_lh_fnames, percent_grad_seg_rh_fnames)\n", "\n", - "plot_gradient(data_dir, percent_grad_seg_fnames, title=False, out_dir=percent_grad_out_path)" + "plot_gradient(data_dir, percent_grad_seg_fnames, cmap=\"YlOrRd\", color_range=(0,1), title=False, out_dir=percent_grad_out_path)" ] }, { @@ -73,13 +108,13 @@ "metadata": {}, "outputs": [], "source": [ - "kmeans_grad_seg_path = \"../results/segmentation/kmeans\"\n", + "kmeans_grad_seg_path = \"../results/segmentation/KMeans_gradient-maps\"\n", "kmeans_grad_out_path = \"../figures/Fig/segmentation/kmeans\"\n", - "kmeans_grad_seg_lh_fnames = sorted(glob(op.join(kmeans_grad_seg_path, \"*KMeans*_desc-Bin*-L_feature.func.gii\")))\n", - "kmeans_grad_seg_rh_fnames = sorted(glob(op.join(kmeans_grad_seg_path, \"*KMeans*_desc-Bin*-R_feature.func.gii\")))\n", + "kmeans_grad_seg_lh_fnames = sorted(glob(op.join(kmeans_grad_seg_path, \"*KMeans*_desc-C*-L_feature.func.gii\")))\n", + "kmeans_grad_seg_rh_fnames = sorted(glob(op.join(kmeans_grad_seg_path, \"*KMeans*_desc-C*-R_feature.func.gii\")))\n", "kmeans_grad_seg_fnames = zip(kmeans_grad_seg_lh_fnames, kmeans_grad_seg_rh_fnames)\n", "\n", - "plot_gradient(data_dir, kmeans_grad_seg_fnames, title=False, out_dir=kmeans_grad_out_path)" + "plot_gradient(data_dir, kmeans_grad_seg_fnames, cmap=\"YlOrRd\", color_range=(0,1), title=False, out_dir=kmeans_grad_out_path)" ] }, { @@ -96,13 +131,76 @@ "metadata": {}, "outputs": [], "source": [ - "kde_grad_seg_path = \"../results/segmentation/kde\"\n", + "kde_grad_seg_path = \"../results/segmentation/KDE_gradient-maps\"\n", "kde_grad_out_path = \"../figures/Fig/segmentation/kde\"\n", - "kde_grad_seg_lh_fnames = sorted(glob(op.join(kde_grad_seg_path, \"*KDE*_desc-Bin*-L_feature.func.gii\")))\n", - "kde_grad_seg_rh_fnames = sorted(glob(op.join(kde_grad_seg_path, \"*KDE*_desc-Bin*-L_feature.func.gii\")))\n", + "kde_grad_seg_lh_fnames = sorted(glob(op.join(kde_grad_seg_path, \"*KDE*_desc-C*-L_feature.func.gii\")))\n", + "kde_grad_seg_rh_fnames = sorted(glob(op.join(kde_grad_seg_path, \"*KDE*_desc-C*-R_feature.func.gii\")))\n", + "print(kde_grad_seg_lh_fnames)\n", + "print(kde_grad_seg_rh_fnames)\n", "kde_grad_seg_fnames = zip(kde_grad_seg_lh_fnames, 
kde_grad_seg_rh_fnames)\n", "\n", - "plot_gradient(data_dir, kde_grad_seg_fnames, title=False, out_dir=kde_grad_out_path)" + "plot_gradient(data_dir, kde_grad_seg_fnames, cmap=\"YlOrRd\", color_range=(0,1), title=False, out_dir=kde_grad_out_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import gc\n", + "\n", + "from neuromaps.datasets import fetch_fslr\n", + "from surfplot import Plot\n", + "from surfplot.utils import threshold\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "def plot_surf_maps(lh_grad, rh_grad, color_range, cmap, dpi, data_dir, out_filename):\n", + " neuromaps_dir = op.join(data_dir, \"neuromaps\")\n", + "\n", + " surfaces = fetch_fslr(density=\"32k\", data_dir=neuromaps_dir)\n", + " lh, rh = surfaces[\"inflated\"]\n", + " sulc_lh, sulc_rh = surfaces[\"sulc\"]\n", + "\n", + " p = Plot(lh, views=\"lateral\")\n", + " p.add_layer({\"left\": sulc_lh}, cmap=\"binary_r\", cbar=False)\n", + " p.add_layer({\"left\": lh_grad}, cmap=cmap, cbar=False, color_range=color_range,)\n", + " fig = p.build()\n", + "\n", + " fig.savefig(out_filename, bbox_inches=\"tight\", dpi=dpi, transparent=True)\n", + " fig = None\n", + " plt.close()\n", + " gc.collect()\n", + " plt.clf()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "segmentations = [3, 17, 32]\n", + "seg_sols = [[0, 1, 2], [0, 8, 16], [0, 15, 31]]\n", + "methods = [\"PCT\", \"KMeans\", \"KDE\"]\n", + "grad_out_path = \"../figures/Fig/segmentation\"\n", + "\n", + "for seg_i, segmentation in enumerate(segmentations):\n", + " print(f\"Segmentation: {segmentation}\")\n", + " for seg_sol in seg_sols[seg_i]:\n", + " for method in methods:\n", + " grad_seg_path = f\"../results/segmentation/{method}_gradient-maps\"\n", + " grad_seg_lh_fnames = sorted(glob(op.join(grad_seg_path, f\"*{method}{segmentation:02d}_desc-C{seg_sol:02d}_*-L_feature.func.gii\")))\n", + " grad_seg_rh_fnames = sorted(glob(op.join(grad_seg_path, f\"*{method}{segmentation:02d}_desc-C{seg_sol:02d}_*-R_feature.func.gii\")))\n", + " print(grad_seg_lh_fnames)\n", + " print(grad_seg_rh_fnames)\n", + " grad_seg_fnames = zip(grad_seg_lh_fnames, grad_seg_rh_fnames)\n", + "\n", + " # plot_gradient(data_dir, grad_seg_fnames, cmap=\"YlOrRd\", color_range=(0,1), views=\"lateral\", title=False, out_dir=grad_out_path)\n", + " for lh_grad, rh_grad in grad_seg_fnames:\n", + " out_filename = op.join(grad_out_path, f\"{method}{segmentation:02d}_C{seg_sol:02d}.tiff\")\n", + " plot_surf_maps(lh_grad, rh_grad, (0,1), \"YlOrRd\", 100, data_dir, out_filename)\n" ] }, { @@ -119,7 +217,7 @@ "img_lbs = [\"PCT\", \"KMeans\", \"KDE\"]\n", "step = 0\n", "row = 0\n", - "for segment_size in range(3, 33):\n", + "for segment_size in range(2, 33):\n", " pct_files = sorted(glob(op.join(figures_dir, \"Fig\", \"segmentation\", \"pct\", f\"*{segment_size:02d}-*.tiff\")))\n", " kms_files = sorted(glob(op.join(figures_dir, \"Fig\", \"segmentation\", \"kmeans\", f\"*{segment_size:02d}-*.tiff\")))\n", " kde_files = sorted(glob(op.join(figures_dir, \"Fig\", \"segmentation\", \"kde\", f\"*{segment_size:02d}-*.tiff\")))\n", @@ -175,12 +273,92 @@ "metadata": {}, "outputs": [], "source": [ - "with open(op.join(percent_grad_seg_path, \"pct_results.pkl\"), \"rb\") as results_file:\n", + "with open(op.join(\"../results/segmentation\", \"new_PCT_results.pkl\"), \"rb\") as results_file:\n", " pct_dict = pickle.load(results_file)\n", - "with 
open(op.join(kmeans_grad_seg_path, \"kmeans_results.pkl\"), \"rb\") as results_file:\n", + "with open(op.join(\"../results/segmentation\", \"new_KMeans_results.pkl\"), \"rb\") as results_file:\n", " kmeans_dict = pickle.load(results_file)\n", - "with open(op.join(kde_grad_seg_path, \"kde_results.pkl\"), \"rb\") as results_file:\n", - " kde_dict = pickle.load(results_file)" + "with open(op.join(\"../results/segmentation\", \"new_KDE_results.pkl\"), \"rb\") as results_file:\n", + " kde_dict = pickle.load(results_file)\n", + "\n", + "dict_list = [pct_dict, kmeans_dict, kde_dict]\n", + "label_list = [\"PCT\", \"KMeans\", \"KDE\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn import metrics\n", + "import seaborn as sns\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "methods_lst = []\n", + "seg_sol_lst = []\n", + "metrics_lst = []\n", + "for seg in range(31):\n", + " for i in range(len(dict_list)):\n", + " for j in range(i):\n", + " nmi = metrics.normalized_mutual_info_score(\n", + " dict_list[i][\"labels\"][seg],\n", + " dict_list[j][\"labels\"][seg],\n", + " )\n", + " methods_lst.append(f\"{label_list[i]} ~ {label_list[j]}\")\n", + " seg_sol_lst.append(seg+2)\n", + " metrics_lst.append(nmi)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_df = pd.DataFrame({\"Method\": methods_lst, \"Segmentation\": seg_sol_lst, \"NMI\": metrics_lst})\n", + "data_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.set_style(\"ticks\",{'axes.grid' : True})\n", + "fig, ax = plt.subplots(1, 1)\n", + "fig.set_size_inches(7, 4)\n", + "\n", + "sns.lineplot(\n", + " data=data_df,\n", + " x=\"Segmentation\", \n", + " y=\"NMI\", \n", + " hue=\"Method\", \n", + " style=\"Method\",\n", + " palette=\"rocket_r\",\n", + " markers=True, \n", + " dashes=False, \n", + " ax=ax,\n", + ")\n", + "\n", + "ax.set_ylabel('Normalized Mutual Information (NMI)', fontsize=14)\n", + "ax.set_xlabel('Segmentation Solution', fontsize=14)\n", + "ax.set_xticks(np.arange(2, 33, 2))\n", + "ax.set_xticklabels(np.arange(2, 33, 2), fontsize=13)\n", + "ax.tick_params(axis='y', labelsize=13)\n", + "legend = ax.legend()\n", + "legend.set_title('')\n", + "\n", + "fig.tight_layout()\n", + "fig.savefig(op.join(figures_dir, \"Fig\", \"Fig-S5.png\"), bbox_inches=\"tight\", dpi=1000)\n", + "plt.show()" ] }, { @@ -192,10 +370,10 @@ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", - "n_segments = 30\n", - "min_n_segments = 3\n", + "n_segments = 31\n", + "min_n_segments = 2\n", "colors = [\"#05BFDB\", \"#088395\", \"#0A4D68\"]\n", - "\n", + "min_peak, max_peak = gradient.min(), gradient.max()\n", "boundaries = []\n", "for seg_i, n_segment in enumerate(range(min_n_segments, n_segments + min_n_segments)):\n", " \n", @@ -206,6 +384,7 @@ " for dict_i, results_dict in enumerate([kde_dict, kmeans_dict, pct_dict]):\n", " bound_arr = results_dict[\"boundaries\"][seg_i]\n", " peaks_arr = results_dict[\"peaks\"][seg_i]\n", + " peaks_arr[0], peaks_arr[-1] = min_peak, max_peak\n", "\n", " x = []\n", " y = []\n", @@ -246,6 +425,154 @@ " plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_dir = 
\"/Users/jperaza/Documents/GitHub/gradient-decoding/results/segmentation\"\n", + "n_segments = 1\n", + "for method in [\"PCT\", \"KMeans\", \"KDE\"]:\n", + " results_fn = op.join(output_dir, f\"0_{method}_results.pkl\")\n", + " if not op.isfile(results_fn):\n", + " if method == \"PCT\":\n", + " # Percentile Segmentation\n", + " print(\"\\t\\tRunning Percentile Segmentation...\", flush=True)\n", + " segment_method = PCTLSegmentation(results_fn, n_segments)\n", + " elif method == \"KMeans\":\n", + " # K-Means\n", + " print(\"\\t\\tRunning K-Means Segmentation...\", flush=True)\n", + " segment_method = KMeansSegmentation(results_fn, n_segments)\n", + " elif method == \"KDE\":\n", + " # KDE\n", + " print(\"\\t\\tRunning KDE Segmentation...\", flush=True)\n", + " segment_method = KDESegmentation(results_fn, n_segments)\n", + " \n", + " results_dict = segment_method.fit(gradient)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_dir = \"/Users/jperaza/Documents/GitHub/gradient-decoding/results/segmentation\"\n", + "for method in [\"PCT\", \"KMeans\", \"KDE\"]:\n", + " new_results_fn = op.join(output_dir, f\"0_{method}_results.pkl\")\n", + " old_results_fn = op.join(output_dir, f\"{method}_results.pkl\")\n", + " results_fn = op.join(output_dir, f\"new_{method}_results.pkl\")\n", + "\n", + " with open(new_results_fn, \"rb\") as results_file:\n", + " new_results_dict = pickle.load(results_file)\n", + " \n", + " with open(old_results_fn, \"rb\") as results_file:\n", + " old_results_dict = pickle.load(results_file)\n", + " \n", + " results_dict = old_results_dict.copy()\n", + " results_dict[\"segments\"].insert(0, new_results_dict[\"segments\"][0])\n", + " results_dict[\"boundaries\"].insert(0, new_results_dict[\"boundaries\"][0])\n", + " results_dict[\"peaks\"].insert(0, new_results_dict[\"peaks\"][0])\n", + " results_dict[\"labels\"].insert(0, new_results_dict[\"labels\"][0])\n", + "\n", + " with open(results_fn, \"wb\") as segmentation_file:\n", + " pickle.dump(results_dict, segmentation_file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import nibabel as nib\n", + "\n", + "def scores_to_gii(scores, lh_fn, rh_fn):\n", + " full_vertices = 64984\n", + " hemi_vertices = full_vertices // 2\n", + "\n", + " grad_map_full = add_fslr_medial_wall(scores, split=False)\n", + " grad_map_lh, grad_map_rh = (\n", + " grad_map_full[:hemi_vertices],\n", + " grad_map_full[hemi_vertices:],\n", + " )\n", + " grad_map_lh = np.float32(grad_map_lh)\n", + " grad_map_rh = np.float32(grad_map_rh)\n", + "\n", + " grad_img_lh = GiftiImage()\n", + " grad_img_rh = GiftiImage()\n", + " grad_img_lh.add_gifti_data_array(GiftiDataArray(grad_map_lh))\n", + " grad_img_rh.add_gifti_data_array(GiftiDataArray(grad_map_rh))\n", + "\n", + " # Write cortical gradient to Gifti files\n", + " nib.save(grad_img_lh, lh_fn)\n", + " nib.save(grad_img_rh, rh_fn)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(\"/Users/jperaza/Documents/GitHub/gradient-decoding/results/segmentation/segments.pkl\", \"rb\") as segmentation_file:\n", + " grad_seg_dict = pickle.load(segmentation_file)\n", + "\n", + "for method in [\"PCT\", \"KMeans\", \"KDE\"]:\n", + " out_dir = f\"/Users/jperaza/Documents/GitHub/gradient-decoding/results/segmentation/{method}_gradient-maps\"\n", + " grad_segments = grad_seg_dict[method]\n", + " for grad_segment in 
grad_segments:\n", + " seg_size = len(grad_segment)\n", + " for map_i, grad_map in enumerate(grad_segment):\n", + " print(seg_size, map_i)\n", + " grad_map_lh_fn = op.join(\n", + " out_dir,\n", + " f\"source-{method}{seg_size:02d}_desc-C{map_i:02d}_space-fsLR_den-32k_hemi-L_feature.func.gii\",\n", + " )\n", + " grad_map_rh_fn = op.join(\n", + " out_dir,\n", + " f\"source-{method}{seg_size:02d}_desc-C{map_i:02d}_space-fsLR_den-32k_hemi-R_feature.func.gii\",\n", + " )\n", + "\n", + " #if not (op.isfile(grad_map_lh_fn) and op.isfile(grad_map_rh_fn)):\n", + " scores_to_gii(grad_map, grad_map_lh_fn, grad_map_rh_fn)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for method in [\"PCT\", \"KMeans\", \"KDE\"]:\n", + " out_dir = f\"/Users/jperaza/Documents/GitHub/gradient-decoding/results/segmentation/{method}_confidence-maps\"\n", + " samples_arr_fn = op.join(\"../results\", \"segmentation\", f\"{method}_silhouette.npy\")\n", + " samples_arrays = np.load(samples_arr_fn)\n", + " \n", + " for map_i in range(samples_arrays.shape[0]):\n", + " map_array = samples_arrays[map_i, :]\n", + " seg_size = map_i + 2\n", + "\n", + " grad_map_lh_fn = op.join(\n", + " out_dir,\n", + " f\"source-{method}{seg_size:02d}_desc-SilhouetteSamples_space-fsLR_den-32k_hemi-L_feature.func.gii\",\n", + " )\n", + " grad_map_rh_fn = op.join(\n", + " out_dir,\n", + " f\"source-{method}{seg_size:02d}_desc-SilhouetteSamples_space-fsLR_den-32k_hemi-R_feature.func.gii\",\n", + " )\n", + "\n", + " scores_to_gii(map_array, grad_map_lh_fn, grad_map_rh_fn)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/figures/04_silhouette-samples.ipynb b/figures/04_silhouette-samples.ipynb index 860f828..53dc9e2 100644 --- a/figures/04_silhouette-samples.ipynb +++ b/figures/04_silhouette-samples.ipynb @@ -15,7 +15,6 @@ "from matplotlib.gridspec import GridSpec\n", "import pandas as pd\n", "import numpy as np\n", - "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "import matplotlib.cm as cm\n", "import seaborn as sns\n", @@ -54,18 +53,21 @@ "metadata": {}, "outputs": [], "source": [ - "silhouette_scores_df = pd.read_csv(op.join(results_dir, \"segmentation\", \"silhouette_scores.csv\"))\n", - "n_score = silhouette_scores_df.shape[0]\n", - "\n", - "mean_scores_df = pd.DataFrame()\n", - "mean_scores_df[\"segment_sizes\"] = silhouette_scores_df[\"segment_sizes\"].tolist()\n", - "segmentation_lst = [\"PCT\"] * n_score + [\"KMeans\"] * n_score + [\"KDE\"] * n_score\n", - "silhouette_score_lst = silhouette_scores_df[\"percentile\"].tolist() + silhouette_scores_df[\"kmeans\"].tolist() + silhouette_scores_df[\"kde\"].tolist()\n", - "\n", - "mean_scores_df = pd.concat([mean_scores_df] * 3)\n", - "mean_scores_df[\"segmentation\"] = segmentation_lst\n", - "mean_scores_df[\"silhouette_score\"] = silhouette_score_lst\n", - "mean_scores_df = mean_scores_df.reset_index()" + "hd_scores_df = pd.read_csv(op.join(results_dir, \"segmentation\", \"scores_high-dimensional.csv\"))\n", + "ld_scores_df = pd.read_csv(op.join(results_dir, \"segmentation\", \"scores_uni-dimensional.csv\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metric_dict = {\n", + " \"silhouette\": \"Mean Silhouette Coefficient\",\n", + " \"variance_ratio\": \"Variance Ratio\",\n", + " \"cluster_separation\": \"Cluster 
Separation\",\n", + "}" ] }, { @@ -75,31 +77,93 @@ "outputs": [], "source": [ "sns.set_style(\"ticks\",{'axes.grid' : True})\n", - "plt.figure(figsize=(6,4.3))\n", "\n", - "idxes = [0, 2, 4]\n", + "metrics = [\"silhouette\", \"variance_ratio\", \"cluster_separation\"]\n", + "\n", + "fig_1, axes_1 = plt.subplots(3, 1)\n", + "fig_1.set_size_inches(6.5, 15)\n", + "fig_2, axes_2 = plt.subplots(3, 1)\n", + "fig_2.set_size_inches(6.5, 15)\n", + "\n", "colors = [\"#0A4D68\", \"#088395\", \"#05BFDB\"]\n", "hue_order = [\"PCT\", \"KMeans\", \"KDE\"]\n", "\n", - "ax = sns.lineplot(\n", - " data=mean_scores_df,\n", - " x=\"segment_sizes\", \n", - " y=\"silhouette_score\", \n", - " hue=\"segmentation\", \n", - " style=\"segmentation\",\n", - " markers=True, \n", - " dashes=False, \n", - " hue_order=hue_order,\n", - " palette=colors,\n", - ")\n", - "ax.set_xlabel('Segment Solution', fontsize=16)\n", - "plt.xticks(fontsize=14)\n", - "ax.set_ylabel('Mean Silhouette Coefficient', fontsize=16)\n", - "plt.yticks(fontsize=14)\n", - "ax.legend(title='Segmentation')\n", - "\n", - "#plt.savefig(op.join(figures_dir, \"Fig\", \"silhouette\", \"silhouette_scores.eps\"), bbox_inches=\"tight\")\n", - "# plt.savefig(op.join(figures_dir, \"Fig\", \"Fig-05.eps\"), bbox_inches=\"tight\")" + "for met_i, metric in enumerate(metrics):\n", + " sns.lineplot(\n", + " data=ld_scores_df,\n", + " x=\"segment\", \n", + " y=metric, \n", + " hue=\"method\", \n", + " style=\"method\",\n", + " markers=True, \n", + " dashes=False, \n", + " hue_order=hue_order,\n", + " palette=colors,\n", + " ax=axes_1[met_i],\n", + " )\n", + "\n", + " sns.lineplot(\n", + " data=hd_scores_df,\n", + " x=\"segment\", \n", + " y=metric, \n", + " hue=\"component\", \n", + " style=\"component\",\n", + " palette=\"rocket_r\",\n", + " markers=True, \n", + " dashes=False, \n", + " ax=axes_2[met_i],\n", + " )\n", + " if met_i == 2:\n", + " # axes[met_i, 1].set_ylim([1, 6])\n", + " axes_1[met_i].set_xlabel('Segment Solution', fontsize=18)\n", + " axes_2[met_i].set_xlabel('Segment Solution', fontsize=18)\n", + " axes_1[met_i].tick_params(axis='x', labelsize=16)\n", + " axes_2[met_i].tick_params(axis='x', labelsize=16)\n", + " legend_1 = axes_2[met_i].legend(title='Components', fontsize=14)\n", + " legend_1.get_title().set_fontsize('16')\n", + " legend_2 = axes_1[met_i].legend(title='Segmentation', fontsize=14)\n", + " legend_2.get_title().set_fontsize('16')\n", + " else:\n", + " axes_1[met_i].set_xticklabels([])\n", + " axes_1[met_i].set_xlabel('')\n", + " axes_2[met_i].set_xticklabels([])\n", + " axes_2[met_i].set_xlabel('')\n", + " axes_2[met_i].legend_.remove()\n", + " axes_1[met_i].legend_.remove()\n", + "\n", + " if met_i == 0:\n", + " axes_2[met_i].set_title(\"High-Dimensional K-means Clustering\", fontsize=20)\n", + " axes_1[met_i].set_title(\"One-Dimensional Segmentation\", fontsize=20)\n", + " \n", + " axes_1[met_i].set_ylabel(metric_dict[metric], fontsize=18)\n", + " axes_2[met_i].set_ylabel(metric_dict[metric], fontsize=18)\n", + " axes_2[met_i].tick_params(axis='y', labelsize=16)\n", + " axes_1[met_i].tick_params(axis='y', labelsize=16)\n", + "\n", + " if met_i != 2:\n", + " index_1 = hd_scores_df[metric].idxmax()\n", + " seg_1 = hd_scores_df.loc[index_1, 'segment']\n", + " value_1 = hd_scores_df[metric].max()\n", + " index_0 = ld_scores_df[metric].idxmax()\n", + " seg_0 = ld_scores_df.loc[index_0, 'segment']\n", + " value_0 = ld_scores_df[metric].max()\n", + " else:\n", + " index_1 = hd_scores_df[metric].idxmin()\n", + " seg_1 = 
hd_scores_df.loc[index_1, 'segment']\n", + " value_1 = hd_scores_df[metric].min()\n", + " index_0 = ld_scores_df[metric].idxmin()\n", + " seg_0 = ld_scores_df.loc[index_0, 'segment']\n", + " value_0 = ld_scores_df[metric].min()\n", + " \n", + " axes_2[met_i].scatter(seg_1, value_1, facecolors='none', edgecolors='red', s=150)\n", + " axes_1[met_i].scatter(seg_0, value_0, facecolors='none', edgecolors='red', s=150)\n", + " \n", + "# plt.savefig(op.join(figures_dir, \"Fig\", \"silhouette\", \"silhouette_scores.eps\"), bbox_inches=\"tight\")\n", + "# plt.tight_layout()\n", + "fig_1.tight_layout()\n", + "fig_2.tight_layout()\n", + "fig_1.savefig(op.join(figures_dir, \"Fig\", \"Fig-04.eps\"))\n", + "fig_2.savefig(op.join(figures_dir, \"Fig\", \"Fig-10.eps\"))" ] }, { @@ -119,11 +183,11 @@ "# Get samples and labels data\n", "samples_arrays = []\n", "labels_lst = []\n", - "for method in [\"pct\", \"kmeans\", \"kde\"]:\n", - " samples_arr_fn = op.join(results_dir, \"segmentation\", method, f\"{method}_samples.npy\")\n", + "for method in [\"PCT\", \"KMeans\", \"KDE\"]:\n", + " samples_arr_fn = op.join(results_dir, \"segmentation\", f\"{method}_silhouette.npy\")\n", " samples_arrays.append(np.load(samples_arr_fn))\n", "\n", - " results_dict_fn = op.join(results_dir, \"segmentation\", method, f\"{method}_results.pkl\")\n", + " results_dict_fn = op.join(results_dir, \"segmentation\", f\"{method}_results.pkl\")\n", " with open(results_dict_fn, \"rb\") as results_dict_file:\n", " results_dict = pickle.load(results_dict_file)\n", " labels_lst.append(results_dict[\"labels\"])" @@ -136,8 +200,8 @@ "outputs": [], "source": [ "# Create a pandas dataframe (violin_df) for violin plots.\n", - "min_n_segments = 3\n", - "n_segments = 30\n", + "min_n_segments = 2\n", + "n_segments = 31\n", "segment_sizes = np.arange(min_n_segments, n_segments + min_n_segments)\n", "\n", "violin_df = pd.DataFrame()\n", @@ -165,21 +229,24 @@ "source": [ "sns.set_style(\"ticks\")\n", "colors = [\"#0A4D68\", \"#088395\", \"#05BFDB\"]\n", + "hue_order = [\"PCT\", \"KMeans\", \"KDE\"]\n", + "\n", + "yticks = np.array([-0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8])\n", + "\n", "fontsize = 12\n", "color = plt.colormaps[\"viridis\"]\n", "ort = \"v\"\n", "dy = \"sample_scores\"\n", "dx = \"segmentation\"\n", "\n", - "for segm_i in range(30):\n", - " segment_size = segm_i + 3\n", + "for segm_i in range(31):\n", + " segment_size = segm_i + 2\n", "\n", " fig, axes = plt.subplots(1, 4)\n", " fig.set_size_inches(7.5, 3)\n", "\n", - " for method_i, method in enumerate([\"Percentile\", \"KMeans\", \"KDE\"]):\n", - " method_name = method.lower() if method != \"Percentile\" else \"pct\"\n", - " with open(op.join(results_dir, \"segmentation\", method_name, f\"{method_name}_results.pkl\"), \"rb\") as results_file:\n", + " for method_i, method in enumerate([\"PCT\", \"KMeans\", \"KDE\"]):\n", + " with open(op.join(results_dir, \"segmentation\", f\"{method}_results.pkl\"), \"rb\") as results_file:\n", " results_dict = pickle.load(results_file)\n", " bound_arr = results_dict[\"boundaries\"][segm_i]\n", " \n", @@ -207,7 +274,6 @@ " alpha=1,\n", " )\n", "\n", - " yticks = np.arange(-1, 1.5, 0.5)\n", " imb_axis.set_yticks(yticks)\n", " imb_axis.axes.yaxis.set_ticklabels([])\n", " imb_axis.set_xticks([x_min, x_med, x_max])\n", @@ -217,7 +283,8 @@ " imb_axis.set_xlabel(\"Cluster Imbalance\", fontsize=fontsize)\n", " imb_axis.set_title(method)\n", " # The vertical line for average silhouette score of all the values\n", - " 
imb_axis.axhline(y=silhouette_scores_df[method.lower()][segm_i], color=\"black\", linestyle=\"--\")\n", + " y_line = ld_scores_df[ld_scores_df.segment == segment_size][\"silhouette\"].values[0]\n", + " imb_axis.axhline(y=y_line, color=\"black\", linestyle=\"--\")\n", " imb_axis.grid(axis='y', which='major', color='gray', alpha=0.5)\n", "\n", " pt.half_violinplot(\n", @@ -272,8 +339,8 @@ " # axes[0].set_xlabel(\"Silhouette Distribution\", fontsize=fontsize)\n", " axes[0].set_ylabel(\"Silhouette Coefficient\", fontsize=fontsize)\n", " axes[0].set_xlabel(\"\")\n", - " axes[0].set_yticks([-1, -0.5, 0, 0.5, 1])\n", - " axes[0].set_yticklabels([-1, -0.5, 0, 0.5, 1], fontsize=fontsize-2)\n", + " axes[0].set_yticks(yticks)\n", + " axes[0].set_yticklabels(yticks, fontsize=fontsize-2)\n", " # axes[0].set_xticklabels([])\n", " axes[0].grid(axis='y', which='major', color='gray', alpha=0.5)\n", " \n", @@ -307,6 +374,8 @@ "sns.set_style(\"ticks\")\n", "\n", "colors = [\"#0A4D68\", \"#088395\", \"#05BFDB\"]\n", + "hue_order = [\"PCT\", \"KMeans\", \"KDE\"]\n", + "yticks = np.array([-0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8])\n", "\n", "ort = \"v\"\n", "dy = \"sample_scores\"\n", @@ -317,11 +386,10 @@ "for segm_i, segment_size in enumerate([3, 17, 32]):\n", " vio_axis = axes_tpl[segm_i, 0]\n", "\n", - " for method_i, method in enumerate([\"Percentile\", \"KMeans\", \"KDE\"]):\n", - " method_name = method.lower() if method != \"Percentile\" else \"pct\"\n", - " with open(op.join(results_dir, \"segmentation\", method_name, f\"{method_name}_results.pkl\"), \"rb\") as results_file:\n", + " for method_i, method in enumerate([\"PCT\", \"KMeans\", \"KDE\"]):\n", + " with open(op.join(results_dir, \"segmentation\", f\"new_{method}_results.pkl\"), \"rb\") as results_file:\n", " results_dict = pickle.load(results_file)\n", - " bound_arr = results_dict[\"boundaries\"][segment_size-3]\n", + " bound_arr = results_dict[\"boundaries\"][segment_size-2]\n", " \n", " imb_axis = axes_tpl[segm_i, method_i + 1]\n", "\n", @@ -333,7 +401,7 @@ " boun_i, boun_j = (bound_arr[cluster_i], bound_arr[cluster_i + 1])\n", " # Aggregate the silhouette scores for samples belonging to\n", " # cluster i, and sort them\n", - " ith_cluster_silhouette_values = samples_arrays[method_i][segment_size-3, labels_lst[method_i][segment_size-3] == cluster_i]\n", + " ith_cluster_silhouette_values = samples_arrays[method_i][segment_size-2, labels_lst[method_i][segment_size-2] == cluster_i]\n", "\n", " ith_cluster_silhouette_values.sort()\n", "\n", @@ -348,7 +416,7 @@ " alpha=1,\n", " )\n", "\n", - " yticks = np.arange(-1, 1.5, 0.5)\n", + " # yticks = np.arange(-1, 1.5, 0.5)\n", " imb_axis.set_yticks(yticks)\n", " imb_axis.set_xticks([]) # Clear the yaxis labels / ticks\n", " imb_axis.axes.yaxis.set_ticklabels([])\n", @@ -358,7 +426,8 @@ " imb_axis.set_xticklabels([x_min, x_med, x_max], fontsize=18)\n", "\n", " # The vertical line for average silhouette score of all the values\n", - " imb_axis.axhline(y=silhouette_scores_df[method.lower()][segment_size-3], color=\"black\", linestyle=\"--\")\n", + " y_line = ld_scores_df[(ld_scores_df.segment == segment_size) & (ld_scores_df.method == method)][\"silhouette\"].values[0]\n", + " imb_axis.axhline(y=y_line, color=\"black\", linestyle=\"--\")\n", " imb_axis.grid(axis='y', which='major', color='gray', alpha=0.5)\n", "\n", " pt.half_violinplot(\n", @@ -411,8 +480,8 @@ "\n", " vio_axis.set_ylabel(\"\")\n", " vio_axis.set_xlabel(\"\")\n", - " vio_axis.set_yticks([-1, -0.5, 0, 0.5, 1])\n", - " 
vio_axis.set_yticklabels([-1, -0.5, 0, 0.5, 1], fontsize=18)\n", + " vio_axis.set_yticks(yticks)\n", + " vio_axis.set_yticklabels(yticks, fontsize=18)\n", " vio_axis.set_xticklabels([])\n", " vio_axis.grid(axis='y', which='major', color='gray', alpha=0.5)\n", "\n", @@ -449,12 +518,12 @@ "metadata": {}, "outputs": [], "source": [ - "percentile_samples_seg_path = op.join(results_dir, \"segmentation\", \"silhouette\", \"pct\")\n", + "percentile_samples_seg_path = op.join(results_dir, \"segmentation\", \"PCT_confidence-maps\")\n", "percentile_samples_seg_lh_fnames = sorted(glob(op.join(percentile_samples_seg_path, \"*desc-SilhouetteSamples_space-fsLR_den-32k_hemi-L_feature.func.gii\")))\n", "percentile_samples_seg_rh_fnames = sorted(glob(op.join(percentile_samples_seg_path, \"*desc-SilhouetteSamples_space-fsLR_den-32k_hemi-R_feature.func.gii\")))\n", "percentile_samples_seg_fnames = zip(percentile_samples_seg_lh_fnames, percentile_samples_seg_rh_fnames)\n", "\n", - "plot_gradient(data_dir, percentile_samples_seg_fnames, cmap=\"afmhot\", color_range=(-0.7,.9), out_dir=op.join(figures_dir, \"Fig\", \"silhouette\"))" + "plot_gradient(data_dir, percentile_samples_seg_fnames, cmap=\"afmhot\", color_range=(-0.5,.5), out_dir=op.join(figures_dir, \"Fig\", \"silhouette\"))" ] }, { @@ -471,12 +540,12 @@ "metadata": {}, "outputs": [], "source": [ - "kmeans_samples_seg_path = op.join(results_dir, \"segmentation\", \"silhouette\", \"kmeans\")\n", + "kmeans_samples_seg_path = op.join(results_dir, \"segmentation\", \"KMeans_confidence-maps\")\n", "kmeans_samples_seg_lh_fnames = sorted(glob(op.join(kmeans_samples_seg_path, \"*desc-SilhouetteSamples_space-fsLR_den-32k_hemi-L_feature.func.gii\")))\n", "kmeans_samples_seg_rh_fnames = sorted(glob(op.join(kmeans_samples_seg_path, \"*desc-SilhouetteSamples_space-fsLR_den-32k_hemi-R_feature.func.gii\")))\n", "kmeans_samples_seg_fnames = zip(kmeans_samples_seg_lh_fnames, kmeans_samples_seg_rh_fnames)\n", "\n", - "plot_gradient(data_dir, kmeans_samples_seg_fnames, cmap=\"afmhot\", color_range=(-0.7,.9), out_dir=op.join(figures_dir, \"Fig\", \"silhouette\"))" + "plot_gradient(data_dir, kmeans_samples_seg_fnames, cmap=\"afmhot\", color_range=(-0.5,.5), out_dir=op.join(figures_dir, \"Fig\", \"silhouette\"))" ] }, { @@ -495,12 +564,12 @@ "source": [ "data_dir = \"../data\"\n", "\n", - "kde_samples_seg_path = op.join(results_dir, \"segmentation\", \"silhouette\", \"kde\")\n", + "kde_samples_seg_path = op.join(results_dir, \"segmentation\", \"KDE_confidence-maps\")\n", "kde_samples_seg_lh_fnames = sorted(glob(op.join(kde_samples_seg_path, \"*desc-SilhouetteSamples_space-fsLR_den-32k_hemi-L_feature.func.gii\")))\n", "kde_samples_seg_rh_fnames = sorted(glob(op.join(kde_samples_seg_path, \"*desc-SilhouetteSamples_space-fsLR_den-32k_hemi-R_feature.func.gii\")))\n", "kde_samples_seg_fnames = zip(kde_samples_seg_lh_fnames, kde_samples_seg_rh_fnames)\n", "\n", - "plot_gradient(data_dir, kde_samples_seg_fnames, cmap=\"afmhot\", color_range=(-0.7,.9), out_dir=op.join(figures_dir, \"Fig\", \"silhouette\"))" + "plot_gradient(data_dir, kde_samples_seg_fnames, cmap=\"afmhot\", color_range=(-0.5,.5), out_dir=op.join(figures_dir, \"Fig\", \"silhouette\"))" ] }, { @@ -520,12 +589,12 @@ "img_cbar = op.join(figures_dir, \"Fig\", \"silhouette\", \"silhouette_samples_cbar.png\")\n", "\n", "img_lbs = [\"PCT\", \"KMeans\", \"KDE\"]\n", - "pct_files = sorted(glob(op.join(figures_dir, \"Fig\", \"silhouette\", \"Percentile*-SilhouetteSamples.tiff\")))\n", + "pct_files = 
sorted(glob(op.join(figures_dir, \"Fig\", \"silhouette\", \"PCT*-SilhouetteSamples.tiff\")))\n", "kms_files = sorted(glob(op.join(figures_dir, \"Fig\", \"silhouette\", \"KMeans*-SilhouetteSamples.tiff\")))\n", "kde_files = sorted(glob(op.join(figures_dir, \"Fig\", \"silhouette\", \"KDE*-SilhouetteSamples.tiff\")))\n", "step = 0\n", "row = 0\n", - "for segment_size, (pct_file, kms_file, kde_file) in enumerate(zip(pct_files, kms_files, kde_files), start=3):\n", + "for segment_size, (pct_file, kms_file, kde_file) in enumerate(zip(pct_files, kms_files, kde_files), start=2):\n", " add_title = False\n", " if step % 5 == 0:\n", " add_title = True\n", @@ -566,7 +635,7 @@ " ax.imshow(img)\n", " ax.set_axis_off()\n", "\n", - " if row == 4:\n", + " if row == 4 or segment_size == 32:\n", " row = 0\n", " fig.tight_layout(pad=0.1, w_pad=0.1)\n", " # plt.subplots_adjust(top=0.95)\n", diff --git a/figures/07_decoding.ipynb b/figures/07_decoding.ipynb index 8ba9393..b48bddc 100644 --- a/figures/07_decoding.ipynb +++ b/figures/07_decoding.ipynb @@ -40,20 +40,20 @@ "outputs": [], "source": [ "hue_order = [\n", - " 'term_neurosynth_Percentile',\n", - " 'term_neuroquery_Percentile',\n", + " 'term_neurosynth_PCT',\n", + " 'term_neuroquery_PCT',\n", " 'term_neurosynth_KMeans', \n", " \"term_neuroquery_KMeans\", \n", " \"term_neurosynth_KDE\", \n", " \"term_neuroquery_KDE\",\n", - " 'lda_neurosynth_Percentile',\n", - " 'lda_neuroquery_Percentile',\n", + " 'lda_neurosynth_PCT',\n", + " 'lda_neuroquery_PCT',\n", " 'lda_neurosynth_KMeans', \n", " \"lda_neuroquery_KMeans\", \n", " \"lda_neurosynth_KDE\", \n", " \"lda_neuroquery_KDE\",\n", - " 'gclda_neurosynth_Percentile',\n", - " 'gclda_neuroquery_Percentile',\n", + " 'gclda_neurosynth_PCT',\n", + " 'gclda_neuroquery_PCT',\n", " 'gclda_neurosynth_KMeans', \n", " \"gclda_neuroquery_KMeans\", \n", " \"gclda_neurosynth_KDE\", \n", @@ -67,14 +67,14 @@ "metadata": {}, "outputs": [], "source": [ - "methods = [\"Percentile\", \"KMeans\", \"KDE\"]\n", + "methods = [\"PCT\", \"KMeans\", \"KDE\"]\n", "dset_names = [\"neurosynth\", \"neuroquery\"]\n", "models = [\"term\", \"lda\", \"gclda\"]\n", "\n", "hight = 15\n", "method_nm_lst, seg_sol_lst, corr_val_lst, pval_val_lst, data_df_lst = [], [], [], [], []\n", "data_df_lst = []\n", - "for seg_sol in range(3, 33):\n", + "for seg_sol in range(2, 33):\n", " temp_data_df_lst = []\n", " for dset_name, model, method in itertools.product(dset_names, models, methods):\n", " corr_dir = op.join(dec_data_dir, f\"{dset_name}_{model}_corr_{method}\")\n", @@ -109,12 +109,13 @@ "dy = \"corr\"\n", "dx = \"method\"\n", "file_lbs = [\"uncorrected\", \"corrected\"]\n", - "for seg_sol in range(5, 6):\n", - " data_df = data_df_lst[seg_sol-3]\n", + "for seg_sol in range(2, 33):\n", + " data_df = data_df_lst[seg_sol-2]\n", " for seg_id in range(1, seg_sol+1):\n", " sub_data_uncorr_df = data_df.query(f'seg_sol == {seg_sol} & seg_id == {seg_id}')\n", " sub_data_corr_df = data_df.query(f'seg_sol == {seg_sol} & seg_id == {seg_id} & pval < 0.05')\n", " for file_lb, sub_data_df in zip(file_lbs, [sub_data_uncorr_df, sub_data_corr_df]):\n", + " #for file_lb, sub_data_df in zip(file_lbs, [sub_data_uncorr_df]):\n", " fig, ax = plt.subplots(1, 1)\n", " fig.set_size_inches(15, 2.5)\n", "\n", @@ -195,16 +196,20 @@ "outputs": [], "source": [ "n_cols = 1\n", - "n_rows = 5\n", + "n_rows = 8\n", "w = 7.5\n", - "h = 5\n", + "h = 9\n", "\n", "img_lbs = [\"PCT\", \"KMeans\", \"KDE\"]\n", "step = 0\n", "row = 0\n", - "fig_i = 66\n", - "for segment_size in range(32, 
33):\n", - " for segment_id in range(28, segment_size+1):\n", + "fig_i = 0\n", + "for segment_size in range(2, 33):\n", + " for segment_id in range(1, segment_size+1):\n", + " if fig_i == 65:\n", + " n_rows = 7\n", + " h = 8\n", + " \n", " if row == 0:\n", " out_file = op.join(figures_dir, \"Fig\", \"decoding\", f\"{fig_i:02d}_distributions-corrected.eps\")\n", " print(f\"\\includegraphics[scale=1]{{{fig_i:02d}_distributions-corrected.eps}}\\n\")\n", @@ -272,16 +277,20 @@ "outputs": [], "source": [ "n_cols = 1\n", - "n_rows = 5\n", + "n_rows = 8\n", "w = 7.5\n", - "h = 5\n", + "h = 9\n", "\n", "img_lbs = [\"PCT\", \"KMeans\", \"KDE\"]\n", "step = 0\n", "row = 0\n", - "fig_i = 66\n", - "for segment_size in range(32, 33):\n", - " for segment_id in range(28, segment_size+1):\n", + "fig_i = 0\n", + "for segment_size in range(2, 33):\n", + " for segment_id in range(1, segment_size+1):\n", + " if fig_i == 65:\n", + " n_rows = 7\n", + " h = 8\n", + "\n", " if row == 0:\n", " out_file = op.join(figures_dir, \"Fig\", \"decoding\", f\"{fig_i:02d}_distributions-uncorrected.eps\")\n", " print(f\"\\includegraphics[scale=1]{{{fig_i:02d}_distributions-uncorrected.eps}}\\n\")\n", diff --git a/figures/08_performance-profiles.ipynb b/figures/08_performance-profiles.ipynb index 90006f1..4a43fbf 100644 --- a/figures/08_performance-profiles.ipynb +++ b/figures/08_performance-profiles.ipynb @@ -21,16 +21,57 @@ "metadata": {}, "outputs": [], "source": [ - "def plot_profile(data_df, metric, hue_order, cmap=\"tab20\"):\n", + "from nimare.reports.figures import _reorder_matrix\n", + "# ['single', 'complete', 'average', 'weighted', 'ward']\n", + "def reorder_matrix(temp_data_df, flip_rows=False, flip_cols=False):\n", + " mat = temp_data_df.to_numpy()\n", + " row_labels, col_labels = (\n", + " temp_data_df.index.to_list(),\n", + " temp_data_df.columns.to_list(),\n", + " )\n", + " new_mat, new_row_labels, new_col_labels = _reorder_matrix(\n", + " mat,\n", + " row_labels,\n", + " col_labels,\n", + " \"complete\",\n", + " )\n", + " if flip_rows:\n", + " new_mat = new_mat[::-1, :]\n", + " new_row_labels = new_row_labels[::-1]\n", + " if flip_cols:\n", + " new_mat = new_mat[:, ::-1]\n", + " new_col_labels = new_col_labels[::-1]\n", + " return pd.DataFrame(new_mat, columns=new_col_labels, index=new_row_labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def bin_df(df, percentile_threshold=90):\n", + " threshold_value = df.stack().quantile(percentile_threshold / 100)\n", + " return df.applymap(lambda x: 1 if x > threshold_value else 0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_profile(temp_data_df, metric, hue_order, cmap=\"tab20\"):\n", " sns.set(style=\"whitegrid\")\n", " \n", - " n_segments = 30\n", + " n_segments = 31\n", " for seg_sol in range(n_segments):\n", + " n_seg = seg_sol + 2\n", " fontsize=14\n", " fig, ax = plt.subplots(1, 1)\n", " fig.set_size_inches(9 + seg_sol*0.2, 4)\n", "\n", - " test_df = data_df[data_df[\"segment_solution\"] == seg_sol + 3]\n", + " test_df = temp_data_df[temp_data_df[\"segment_solution\"] == n_seg]\n", " test_df = test_df.reset_index()\n", " test_df[\"segment\"] = test_df[\"segment\"].astype(str)\n", "\n", @@ -97,7 +138,7 @@ " elif metric == \"tfidf\":\n", " ax.set_ylabel('TFIDF', fontsize=fontsize)\n", "\n", - " ax.set_title(f\"Segment Solution {seg_sol + 3:02d}\", fontsize=fontsize)\n", + " ax.set_title(f\"Segment Solution 
{n_seg:02d}\", fontsize=fontsize)\n", " fig.tight_layout()\n", " plt.savefig(op.join(\"./Fig\", \"performance\", f\"{metric}_profile_{seg_sol}.eps\"), bbox_inches=\"tight\")\n", " plt.close()\n", @@ -126,13 +167,14 @@ "metadata": {}, "outputs": [], "source": [ - "def plot_mean_profile(data_df, metric, hue_order, cmap=\"tab20\"):\n", + "def plot_mean_profile(temp_data_df, metric, hue_order, cmap=\"tab20\"):\n", + " temp_data_df[[\"segment_solution\"]] = temp_data_df[[\"segment_solution\"]].astype(str)\n", " sns.set(style=\"white\")\n", " fig, ax = plt.subplots(1, 1)\n", " fig.set_size_inches(3, 15)\n", "\n", " sns.lineplot(\n", - " data=data_df,\n", + " data=temp_data_df,\n", " x=metric,\n", " y=\"segment_solution\",\n", " palette=cmap,\n", @@ -151,8 +193,8 @@ " if metric == \"max_corr\":\n", " fontsize = 12\n", " ax.set_xlabel('Mean Correlation Coefficient', fontsize=fontsize)\n", - " ax.set_xticks([0.1, 0.2, 0.3, 0.4, 0.5])\n", - " ax.set_xticklabels([0.1, 0.2, 0.3, 0.4, 0.5], fontsize=fontsize)\n", + " ax.set_xticks([0.1, 0.3, 0.5, 0.7, 0.9])\n", + " ax.set_xticklabels([0.1, 0.3, 0.5, 0.7, 0.9], fontsize=fontsize)\n", " elif metric == \"ic\":\n", " fontsize = 16\n", " ax.set_xlabel('Information Content', fontsize=fontsize, labelpad=10)\n", @@ -238,25 +280,59 @@ "figure_dir = op.abspath(\"./Fig\")\n", "\n", "hue_order = [\n", - " 'term_neurosynth_Percentile',\n", - " 'term_neuroquery_Percentile',\n", + " 'term_neurosynth_PCT',\n", + " 'term_neuroquery_PCT',\n", " 'term_neurosynth_KMeans', \n", " \"term_neuroquery_KMeans\", \n", " \"term_neurosynth_KDE\", \n", " \"term_neuroquery_KDE\",\n", - " 'lda_neurosynth_Percentile',\n", - " 'lda_neuroquery_Percentile',\n", + " 'lda_neurosynth_PCT',\n", + " 'lda_neuroquery_PCT',\n", " 'lda_neurosynth_KMeans', \n", " \"lda_neuroquery_KMeans\", \n", " \"lda_neurosynth_KDE\", \n", " \"lda_neuroquery_KDE\",\n", - " 'gclda_neurosynth_Percentile',\n", - " 'gclda_neuroquery_Percentile',\n", + " 'gclda_neurosynth_PCT',\n", + " 'gclda_neuroquery_PCT',\n", " 'gclda_neurosynth_KMeans', \n", " \"gclda_neuroquery_KMeans\", \n", " \"gclda_neurosynth_KDE\", \n", " \"gclda_neuroquery_KDE\",\n", - "]" + "]\n", + "\n", + "method_order = [\"PCT\", \"KMeans\", \"KDE\"]\n", + "\n", + "component_order = [\"G1\", \"G1:G2\", \"G1:G3\", \"G1:G4\", \"G1:G5\", \"G1:G6\", \"G1:G7\", \"G1:G8\", \"G1:G9\"]\n", + "\n", + "model_dict = {\n", + " 'term_neurosynth_PCT': \"NS-TERM-PCT\",\n", + " 'term_neuroquery_PCT': \"NQ-TERM-PCT\",\n", + " 'term_neurosynth_KMeans': \"NS-TERM-KMeans\", \n", + " \"term_neuroquery_KMeans\": \"NQ-TERM-KMeans\", \n", + " \"term_neurosynth_KDE\": \"NS-TERM-KDE\", \n", + " \"term_neuroquery_KDE\": \"NQ-TERM-KDE\",\n", + " 'lda_neurosynth_PCT': \"NS-LDA-PCT\",\n", + " 'lda_neuroquery_PCT': \"NQ-LDA-PCT\",\n", + " 'lda_neurosynth_KMeans': \"NS-LDA-KMeans\", \n", + " \"lda_neuroquery_KMeans\": \"NQ-LDA-KMeans\", \n", + " \"lda_neurosynth_KDE\": \"NS-LDA-KDE\", \n", + " \"lda_neuroquery_KDE\": \"NQ-LDA-KDE\",\n", + " 'gclda_neurosynth_PCT': \"NS-GCLDA-PCT\",\n", + " 'gclda_neuroquery_PCT': \"NQ-GCLDA-PCT\",\n", + " 'gclda_neurosynth_KMeans': \"NS-GCLDA-KMeans\", \n", + " \"gclda_neuroquery_KMeans\": \"NQ-GCLDA-KMeans\", \n", + " \"gclda_neurosynth_KDE\": \"NS-GCLDA-KDE\", \n", + " \"gclda_neuroquery_KDE\": \"NQ-GCLDA-KDE\",\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(np.mean([0.739303,0.702998]), np.std([0.739303,0.702998]))" ] }, { @@ -269,6 +345,46 @@ "data_df" ] }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "suub_data_df = data_df[(data_df[\"segment_solution\"] == data_df[\"segment\"]) | (data_df[\"segment\"] == 1)]\n", + "suub_data_df.head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for segm in [\"PCT\", \"KMeans\", \"KDE\"]:\n", + " ic_lst = []\n", + " tfidf_lst = []\n", + " corr_lst = []\n", + " for model in [\"term\", \"lda\", \"gclda\"]:\n", + " for dset in [\"neurosynth\", \"neuroquery\"]:\n", + " model_nm = f\"{model}_{dset}_{segm}\"\n", + " ic_lst.append(data_df.loc[data_df[\"method\"] == model_nm , \"information_content\"].to_list())\n", + " tfidf_lst.append(data_df.loc[data_df[\"method\"] == model_nm, \"tfidf\"].to_list())\n", + " # coor = data_df.loc[(data_df[\"method\"] == model_nm) & (data_df[\"segment_solution\"] == 3), \"max_corr\"].to_list()\n", + " coor = data_df.loc[(data_df[\"method\"] == model_nm), \"max_corr\"].to_list()\n", + " \n", + " corr_lst.append(coor)\n", + "\n", + " # print(model_nm, \"Corr\", np.mean(coor), np.std(coor))\n", + " \n", + " ic_arr = np.hstack(ic_lst)\n", + " tfidf_arr = np.hstack(tfidf_lst)\n", + " corr_arr = np.hstack(corr_lst)\n", + " print(segm, \"Corr\", corr_arr.mean(), corr_arr.std())\n", + " # print(model, \"IC\", ic_arr.mean(), ic_arr.std())\n", + " # print(model, \"TFIDF\", tfidf_arr.mean(), tfidf_arr.std())" + ] + }, { "cell_type": "code", "execution_count": null, @@ -276,10 +392,478 @@ "outputs": [], "source": [ "mean_data_df = pd.read_csv(op.join(result_dir, \"performance\", \"performance_average.tsv\"), delimiter=\"\\t\")\n", - "mean_data_df[\"segment_solution\"] = mean_data_df[\"segment_solution\"].astype(str)\n", + "# mean_data_df[\"segment_solution\"] = mean_data_df[\"segment_solution\"].astype(str)\n", "mean_data_df" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ld_scores_df = pd.read_csv(op.join(result_dir, \"segmentation\", \"scores_uni-dimensional.csv\"))\n", + "ld_scores_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hd_scores_df = pd.read_csv(op.join(result_dir, \"segmentation\", \"scores_high-dimensional.csv\"))\n", + "hd_scores_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sns.set(style=\"whitegrid\")\n", + "from matplotlib.patches import Rectangle\n", + "\n", + "fig, axes = plt.subplots(1, 2)\n", + "fig.set_size_inches(8, 3)\n", + "\n", + "sns.regplot(\n", + " data=data_df, \n", + " x=\"segment_size\", \n", + " y=\"max_corr\",\n", + " logx=True, \n", + " scatter_kws={\"s\": 10},\n", + " line_kws={\"color\": \"r\"}, \n", + " ax=axes[0],\n", + ")\n", + "sns.regplot(\n", + " data=suub_data_df, \n", + " x=\"segment_size\", \n", + " y=\"max_corr\", \n", + " logx=True, \n", + " scatter_kws={\"s\": 10}, \n", + " line_kws={\"color\":\"r\"},\n", + " ax=axes[1],\n", + ")\n", + "\n", + "axes[0].set_xlabel(\"Segment Size\", fontsize=12)\n", + "axes[0].set_ylabel(\"Max Correlation Coefficient\", fontsize=12)\n", + "axes[0].set_title(\"All Segments\", fontsize=12)\n", + "axes[0].set_xlim(1, 42000)\n", + "axes[0].set_ylim(0, 0.8)\n", + "axes[1].set_xlabel(\"Segment Size\", fontsize=12)\n", + "axes[1].set_ylabel(\"\")\n", + "axes[1].set_yticklabels([])\n", + "axes[1].set_title(\"End Segments\", fontsize=12)\n", + "axes[1].set_xlim(1, 42000)\n", + 
"axes[1].set_ylim(0, 0.8)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(op.join(\"./Fig\", \"Fig-S7.png\"), dpi=600, bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "glue_corr = mean_data_df.pivot(index=\"method\", columns=\"segment_solution\", values=\"max_corr\")\n", + "glue_ic = mean_data_df.pivot(index=\"method\", columns=\"segment_solution\", values=\"ic\")\n", + "glue_tfidf = mean_data_df.pivot(index=\"method\", columns=\"segment_solution\", values=\"tfidf\")\n", + "glue_snr = mean_data_df.pivot(index=\"method\", columns=\"segment_solution\", values=\"snr\")\n", + "glue_silhouette = ld_scores_df.pivot(index=\"method\", columns=\"segment\", values=\"silhouette\")\n", + "glue_variance = ld_scores_df.pivot(index=\"method\", columns=\"segment\", values=\"variance_ratio\")\n", + "glue_separation = ld_scores_df.pivot(index=\"method\", columns=\"segment\", values=\"cluster_separation\")\n", + "glue_high_silhouette = hd_scores_df.pivot(index=\"component\", columns=\"segment\", values=\"silhouette\")\n", + "glue_high_variance = hd_scores_df.pivot(index=\"component\", columns=\"segment\", values=\"variance_ratio\")\n", + "glue_high_separation = hd_scores_df.pivot(index=\"component\", columns=\"segment\", values=\"cluster_separation\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "percent = 70\n", + "# glue_corr_sorted = reorder_matrix(glue_corr, flip_rows=False, flip_cols=False)\n", + "glue_corr_sorted = glue_corr.reindex(hue_order)\n", + "glue_corr_sorted.rename(index=model_dict, inplace=True)\n", + "glue_corr_sorted_bin = bin_df(glue_corr_sorted, 90)\n", + "\n", + "# glue_ic_sorted = reorder_matrix(glue_ic, flip_rows=False, flip_cols=True)\n", + "glue_ic_sorted = glue_ic.reindex(hue_order)\n", + "glue_ic_sorted.rename(index=model_dict, inplace=True)\n", + "glue_ic_sorted_bin = bin_df(glue_ic_sorted, percent)\n", + "\n", + "# glue_tfidf_sorted = reorder_matrix(glue_tfidf, flip_rows=True, flip_cols=True)\n", + "glue_tfidf_sorted = glue_tfidf.reindex(hue_order)\n", + "glue_tfidf_sorted.rename(index=model_dict, inplace=True)\n", + "glue_tfidf_sorted_bin = bin_df(glue_tfidf_sorted, percent)\n", + "\n", + "#glue_snr_sorted = reorder_matrix(glue_snr, flip_rows=True, flip_cols=True)\n", + "glue_snr_sorted = glue_snr.reindex(hue_order)\n", + "glue_snr_sorted.rename(index=model_dict, inplace=True)\n", + "glue_snr_sorted_bin = bin_df(glue_snr_sorted, 90)\n", + "\n", + "glue_silhouette_sorted = glue_silhouette.reindex(method_order)\n", + "glue_silhouette_sorted_bin = bin_df(glue_silhouette, 98)\n", + "# glue_silhouette_sorted_bin = (glue_silhouette_sorted == glue_silhouette_sorted.max().max()).astype(int)\n", + "# top_two = glue_silhouette_sorted.unstack().nlargest(2).index\n", + "# glue_silhouette_sorted_bin = glue_silhouette_sorted.copy()\n", + "# glue_silhouette_sorted_bin[:] = 0\n", + "# glue_silhouette_sorted_bin.loc[top_two] = 1\n", + "\n", + "glue_variance_sorted = glue_variance.reindex(method_order)\n", + "glue_variance_sorted_bin = bin_df(glue_variance, 98)\n", + "# glue_variance_sorted_bin = (glue_variance_sorted == glue_variance_sorted.max().max()).astype(int)\n", + "# top_two = glue_variance_sorted.unstack().nlargest(2).index\n", + "# glue_variance_sorted_bin = glue_variance_sorted.copy()\n", + "# glue_variance_sorted_bin[:] = 0\n", + "# glue_variance_sorted_bin.loc[top_two] = 1\n", + "\n", + 
"glue_separation_sorted = glue_separation.reindex(method_order)\n", + "glue_separation_sorted_bin = bin_df(glue_separation*-1, 98)\n", + "# glue_separation_sorted_bin = (glue_separation_sorted == glue_separation_sorted.min().min()).astype(int)\n", + "# top_two = glue_separation_sorted.unstack().nlargest(2).index\n", + "# glue_separation_sorted_bin = glue_separation_sorted.copy()\n", + "# glue_separation_sorted_bin[:] = 0\n", + "# glue_separation_sorted_bin.loc[top_two] = 1\n", + "\n", + "glue_high_silhouette_sorted = glue_high_silhouette.reindex(component_order)\n", + "glue_high_silhouette_sorted_bin = bin_df(glue_high_silhouette, 98)\n", + "\n", + "glue_high_variance_sorted = glue_high_variance.reindex(component_order)\n", + "glue_high_variance_sorted_bin = bin_df(glue_high_variance, 98)\n", + "\n", + "glue_high_separation_sorted = glue_high_separation.reindex(component_order)\n", + "glue_high_separation_sorted_bin = bin_df(glue_high_separation*-1, 98)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#sns.set(style=\"whitegrid\", font_scale=0.8)\n", + "sns.set(style=\"whitegrid\")\n", + "from matplotlib.patches import Rectangle\n", + "\n", + "titles = [\"Mean Silhouette Coefficient\", \"Variance Ratio\", \"Cluster Separation\"]\n", + "fig, axes = plt.subplots(3, 1)\n", + "fig.set_size_inches(8, 8)\n", + "\n", + "data_df_lst = [glue_high_silhouette_sorted, glue_high_variance_sorted, glue_high_separation_sorted]\n", + "data_bin_df_lst = [glue_high_silhouette_sorted_bin, glue_high_variance_sorted_bin, glue_high_separation_sorted_bin]\n", + "for met_i, (dat_df, data_bin_df, title) in enumerate(zip(data_df_lst, data_bin_df_lst, titles)):\n", + " ax = axes[met_i]\n", + "\n", + " if met_i == 2:\n", + " sns.heatmap(dat_df, cmap=\"Blues\", yticklabels=True, vmax=3, ax=ax)\n", + " else:\n", + " sns.heatmap(dat_df, cmap=\"Blues\", yticklabels=True, ax=ax)\n", + " # ax.set_xticklabels(ax.get_xticklabels(), rotation=0, fontsize=16)\n", + " x_lab = [lab if int(lab.get_text()) % 2 == 0 else \"\" for lab in ax.get_xticklabels()]\n", + " ax.set_xticklabels(x_lab, rotation=0, fontsize=14) \n", + " ax.set_yticklabels(ax.get_yticklabels(), rotation=0)\n", + " \n", + " non_zero_indices = np.nonzero(data_bin_df)\n", + " for i in range(len(non_zero_indices[0])):\n", + " ax.add_patch(Rectangle((non_zero_indices[1][i], non_zero_indices[0][i]), 1, 1, fill=False, edgecolor='red', lw=1))\n", + "\n", + " # square=True\n", + " ax.set_ylabel(\"\")\n", + " ax.set_title(title, fontsize=16)\n", + " if met_i == 2:\n", + " ax.set_xlabel('Segment Solution', fontsize=16)\n", + " # ax.set_xticklabels(np.arange(2, 33, 2))\n", + " ax.tick_params(axis='x', labelsize=14)\n", + " else:\n", + " ax.set_xlabel(\"\")\n", + " ax.set_xticklabels([])\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(op.join(\"./Fig\", \"Fig-S10a.png\"), dpi=600, bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import reduce\n", + "clust_high_df_sum = reduce(lambda x, y: x.add(y, fill_value=0), [glue_high_silhouette_sorted_bin, glue_high_variance_sorted_bin, glue_high_separation_sorted_bin])\n", + "clust_high_df_sum" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.colors as colors\n", + "\n", + "fig, ax = plt.subplots(1, 1)\n", + "fig.set_size_inches(8, 3)\n", + "\n", + "n_bins = 
clust_high_df_sum.max().max() + 1\n", + "vals = np.arange(n_bins)\n", + "vals_ticks = vals + 0.5\n", + "vals_labels = [str(lab) for lab in vals]\n", + "cmap = plt.cm.get_cmap('Blues', n_bins)\n", + "norm = colors.BoundaryNorm(np.arange(n_bins+1), cmap.N)\n", + "\n", + "sns.heatmap(\n", + " clust_high_df_sum, \n", + " cmap=cmap,\n", + " vmin=0, \n", + " vmax=3, \n", + " xticklabels=True, \n", + " yticklabels=True, \n", + " ax=ax\n", + ")\n", + "x_lab = [lab if int(lab.get_text()) % 2 == 0 else \"\" for lab in ax.get_xticklabels()]\n", + "ax.set_xticklabels(x_lab, rotation=0, fontsize=14)\n", + "ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=14)\n", + " \n", + "non_zero_indices = np.nonzero(data_bin_df)\n", + "ax.add_patch(Rectangle((2, 6), 1, 1, fill=False, edgecolor='red', lw=3))\n", + "\n", + "ax.set_ylabel(\"\")\n", + "ax.set_xlabel('Segment Solution', fontsize=16)\n", + "colorbar = ax.collections[0].colorbar\n", + "colorbar.set_ticks(vals_ticks)\n", + "colorbar.set_ticklabels(vals_labels)\n", + "ax.set_title(\"Overall Performance\", fontsize=16)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(op.join(\"./Fig\", \"Fig-S10b.png\"), dpi=600, bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#sns.set(style=\"whitegrid\", font_scale=0.8)\n", + "sns.set(style=\"whitegrid\")\n", + "from matplotlib.patches import Rectangle\n", + "\n", + "titles = [\"Mean Silhouette Coefficient\", \"Variance Ratio\", \"Cluster Separation\"]\n", + "fig, axes = plt.subplots(3, 1)\n", + "fig.set_size_inches(8, 4)\n", + "\n", + "data_df_lst = [glue_silhouette_sorted, glue_variance_sorted, glue_separation_sorted]\n", + "data_bin_df_lst = [glue_silhouette_sorted_bin, glue_variance_sorted_bin, glue_separation_sorted_bin]\n", + "for met_i, (dat_df, data_bin_df, title) in enumerate(zip(data_df_lst, data_bin_df_lst, titles)):\n", + " ax = axes[met_i]\n", + "\n", + " sns.heatmap(dat_df, cmap=\"Blues\", yticklabels=True, ax=ax)\n", + " # ax.set_xticklabels(ax.get_xticklabels(), rotation=0, fontsize=16)\n", + " x_lab = [lab if int(lab.get_text()) % 2 == 0 else \"\" for lab in ax.get_xticklabels()]\n", + " ax.set_xticklabels(x_lab, rotation=0, fontsize=14) \n", + " ax.set_yticklabels(ax.get_yticklabels(), rotation=0)\n", + " \n", + " non_zero_indices = np.nonzero(data_bin_df)\n", + " for i in range(len(non_zero_indices[0])):\n", + " ax.add_patch(Rectangle((non_zero_indices[1][i], non_zero_indices[0][i]), 1, 1, fill=False, edgecolor='red', lw=1))\n", + "\n", + " # square=True\n", + " ax.set_ylabel(\"\")\n", + " ax.set_title(title, fontsize=16)\n", + " if met_i == 2:\n", + " ax.set_xlabel('Segment Solution', fontsize=16)\n", + " # ax.set_xticklabels(np.arange(2, 33, 2))\n", + " ax.tick_params(axis='x', labelsize=14)\n", + " else:\n", + " ax.set_xlabel(\"\")\n", + " ax.set_xticklabels([])\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(op.join(\"./Fig\", \"Fig-S08a.png\"), dpi=600, bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import reduce\n", + "clust_df_sum = reduce(lambda x, y: x.add(y, fill_value=0), [glue_silhouette_sorted_bin, glue_variance_sorted_bin, glue_separation_sorted_bin])\n", + "clust_df_sum" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.colors as colors\n", + "\n", + "fig, 
ax = plt.subplots(1, 1)\n",
+    "fig.set_size_inches(8, 2)\n",
+    "\n",
+    "n_bins = clust_df_sum.max().max() + 1\n",
+    "vals = np.arange(n_bins)\n",
+    "vals_ticks = vals + 0.5\n",
+    "vals_labels = [str(lab) for lab in vals]\n",
+    "cmap = plt.cm.get_cmap('Blues', n_bins)\n",
+    "norm = colors.BoundaryNorm(np.arange(n_bins+1), cmap.N)\n",
+    "\n",
+    "sns.heatmap(\n",
+    "    clust_df_sum, \n",
+    "    cmap=cmap,\n",
+    "    vmin=0, \n",
+    "    vmax=4, \n",
+    "    xticklabels=True, \n",
+    "    yticklabels=True, \n",
+    "    ax=ax\n",
+    ")\n",
+    "x_lab = [lab if int(lab.get_text()) % 2 == 0 else \"\" for lab in ax.get_xticklabels()]\n",
+    "ax.set_xticklabels(x_lab, rotation=0, fontsize=14)\n",
+    "ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=14)\n",
+    "    \n",
+    "non_zero_indices = np.nonzero(data_bin_df)\n",
+    "ax.add_patch(Rectangle((0, 1), 1, 1, fill=False, edgecolor='red', lw=3))\n",
+    "\n",
+    "ax.set_ylabel(\"\")\n",
+    "ax.set_xlabel('Segment Solution', fontsize=16)\n",
+    "colorbar = ax.collections[0].colorbar\n",
+    "colorbar.set_ticks(vals_ticks)\n",
+    "colorbar.set_ticklabels(vals_labels)\n",
+    "ax.set_title(\"Overall Performance\", fontsize=16)\n",
+    "\n",
+    "plt.tight_layout()\n",
+    "plt.savefig(op.join(\"./Fig\", \"Fig-S08b.png\"), dpi=600, bbox_inches=\"tight\")\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#sns.set(style=\"whitegrid\", font_scale=0.8)\n",
+    "sns.set(style=\"whitegrid\")\n",
+    "from matplotlib.patches import Rectangle\n",
+    "\n",
+    "titles = [\"Mean Correlation Coefficient\", \"Information Content\", \"Mean TFIDF\", \"Normalized SNR\"]\n",
+    "fig, axes = plt.subplots(2, 2)\n",
+    "fig.set_size_inches(15, 10)\n",
+    "\n",
+    "data_df_lst = [glue_corr_sorted, glue_ic_sorted, glue_tfidf_sorted, glue_snr_sorted]\n",
+    "data_bin_df_lst = [glue_corr_sorted_bin, glue_ic_sorted_bin, glue_tfidf_sorted_bin, glue_snr_sorted_bin]\n",
+    "for met_i, (dat_df, data_bin_df, title) in enumerate(zip(data_df_lst, data_bin_df_lst, titles)):\n",
+    "    \n",
+    "    if met_i==0:\n",
+    "        ax = axes[0,0]\n",
+    "    elif met_i==1:\n",
+    "        ax = axes[1,0]\n",
+    "    elif met_i==2:\n",
+    "        ax = axes[1,1]\n",
+    "    elif met_i==3:\n",
+    "        ax = axes[0,1]\n",
+    "\n",
+    "    sns.heatmap(dat_df, cmap=\"Blues\", yticklabels=True, ax=ax)\n",
+    "    # ax.set_xticklabels(np.arange(2,32,2), rotation=0, fontsize=14)\n",
+    "    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)\n",
+    "    \n",
+    "    non_zero_indices = np.nonzero(data_bin_df)\n",
+    "    for i in range(len(non_zero_indices[0])):\n",
+    "        ax.add_patch(Rectangle((non_zero_indices[1][i], non_zero_indices[0][i]), 1, 1, fill=False, edgecolor='red', lw=1))\n",
+    "\n",
+    "    # square=True\n",
+    "    ax.set_ylabel(\"\")\n",
+    "    ax.set_title(title, fontsize=20)\n",
+    "    if met_i == 1 or met_i == 2:\n",
+    "        ax.set_xlabel('Segment Solution', fontsize=18)\n",
+    "        ax.tick_params(axis='x', labelsize=16)\n",
+    "    else:\n",
+    "        ax.set_xlabel(\"\")\n",
+    "        ax.set_xticklabels([])\n",
+    "\n",
+    "    if met_i == 0 or met_i == 1:\n",
+    "        ax.set_ylabel('')\n",
+    "        # ax.tick_params(axis='y', labelsize=16)\n",
+    "    else:\n",
+    "        ax.set_ylabel('')\n",
+    "        ax.set_yticklabels([])\n",
+    "\n",
+    "plt.tight_layout()\n",
+    "plt.savefig(op.join(\"./Fig\", \"Fig-08a.png\"), dpi=600, bbox_inches=\"tight\")\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Use reduce to sum all binary DataFrames (a standalone sketch appears at the end of this section)\n",
+    "from functools import reduce\n",
+    "df_sum = reduce(lambda x, y: 
x.add(y, fill_value=0), data_bin_df_lst)\n", + "df_sum[2] = df_sum[2] + [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.colors as colors\n", + "\n", + "fig, ax = plt.subplots(1, 1)\n", + "fig.set_size_inches(15, 6)\n", + "\n", + "n_bins = df_sum.max().max() + 1\n", + "vals = np.arange(n_bins)\n", + "vals_ticks = vals + 0.5\n", + "vals_labels = [str(lab) for lab in vals]\n", + "cmap = plt.cm.get_cmap('Blues', n_bins)\n", + "norm = colors.BoundaryNorm(np.arange(n_bins+1), cmap.N)\n", + "\n", + "sns.heatmap(\n", + " df_sum, \n", + " cmap=cmap,\n", + " vmin=0, \n", + " vmax=5, \n", + " xticklabels=True, \n", + " yticklabels=True, \n", + " ax=ax\n", + ")\n", + "ax.set_xticklabels(ax.get_xticklabels(), rotation=0, fontsize=16)\n", + "ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=16)\n", + " \n", + "non_zero_indices = np.nonzero(data_bin_df)\n", + "ax.add_patch(Rectangle((0, 9), 1, 1, fill=False, edgecolor='red', lw=3))\n", + "\n", + "ax.set_ylabel(\"\")\n", + "ax.set_xlabel('Segment Solution', fontsize=18)\n", + "colorbar = ax.collections[0].colorbar\n", + "colorbar.set_ticks(vals_ticks)\n", + "colorbar.set_ticklabels(vals_labels)\n", + "ax.set_title(\"Overall Performance\", fontsize=20)\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig(op.join(\"./Fig\", \"Fig-08b.png\"), dpi=600, bbox_inches=\"tight\")\n", + "plt.show()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -315,7 +899,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/figures/10_term-classification.ipynb b/figures/10_term-classification.ipynb index bfa7d55..6889c76 100644 --- a/figures/10_term-classification.ipynb +++ b/figures/10_term-classification.ipynb @@ -65,7 +65,7 @@ "metadata": {}, "outputs": [], "source": [ - "methods = [\"Percentile\", \"KMeans\", \"KDE\"]\n", + "methods = [\"PCT\", \"KMeans\", \"KDE\"]\n", "dset_names = [\"neurosynth\", \"neuroquery\"]\n", "models = [\"term\", \"lda\", \"gclda\"]\n", "\n", @@ -99,7 +99,7 @@ " ]\n", " corr_idx = temp_df[\"corr_idx\"].values[0]\n", " corr_idx_str = f\"{corr_idx:04d}\" if model == \"term\" else f\"{corr_idx:03d}\"\n", - " \"\"\"\n", + "\n", " # Get gradient maps\n", " maps_fslr = _fetch_metamaps(dset_name, model, data_dir=data_dir)\n", " data = maps_fslr[corr_idx, :]\n", @@ -113,7 +113,7 @@ " prefix = f\"{segmentation:02d}-{seg_sol:02d}-{model_i}-{dset_i}-{iter_i}\"\n", " out_filename = op.join(map_out_path, f\"maps_{prefix}_{dset_name}-{model}-{method}.tiff\")\n", " plot_surf_maps(data_lh, data_rh, threshold_, range_, cmap, 100, data_dir, out_filename)\n", - " \"\"\"\n", + " \n", " feature = temp_df[\"features\"].values[0]\n", "\n", " if model != \"term\":\n", @@ -176,7 +176,7 @@ "source": [ "cotegories = np.array([\"Functional\", \"Clinical\", \"Anatomical\", \"Non-Specific\"])\n", "\n", - "methods = [\"Percentile\", \"KMeans\", \"KDE\"]\n", + "methods = [\"PCT\", \"KMeans\", \"KDE\"]\n", "dset_names = [\"neurosynth\", \"neuroquery\"]\n", "models = [\"term\", \"lda\", \"gclda\"]\n", "\n", @@ -185,7 +185,7 @@ "data_df = pd.read_csv(\"../results/performance/performance.tsv\", delimiter=\"\\t\")\n", "features_lst = []\n", "prefix_lst = []\n", - "for seg_sol in range(3,33):\n", + "for seg_sol in range(2,33):\n", " sub_class_lst = []\n", " seg_sol_lst = []\n", " data2plot_df = pd.DataFrame()\n", diff --git 
a/figures/11_decoding-fig.ipynb b/figures/11_decoding-fig.ipynb index 1733f72..f4c4886 100644 --- a/figures/11_decoding-fig.ipynb +++ b/figures/11_decoding-fig.ipynb @@ -6,131 +6,15 @@ "metadata": {}, "outputs": [], "source": [ - "import math\n", "import os.path as op\n", - "from ast import literal_eval\n", "import itertools\n", "import gc\n", "\n", "import matplotlib.pyplot as plt\n", - "import numpy as np\n", "import pandas as pd\n", - "import matplotlib.cm as cm\n", - "\n", - "from utils import _get_twfrequencies" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def plot_radar(corrs, features, model, fig=None, ax=None, out_fig=None):\n", - " n_rows = 10 if len(corrs) > 10 else len(corrs)\n", - " angle_zero = 0\n", - " fontsize = 36\n", - " \n", - " corrs = corrs[:n_rows]\n", - " features = features[:n_rows]\n", - " angles = [(angle_zero + (n / float(n_rows) * 2 * np.pi)) for n in range(n_rows)]\n", - " if model == \"lda\" or model == \"gclda\":\n", - " features = [\"\\n\".join(feature.split(\"_\")[1:]).replace(\" \", \"\\n\") for feature in features]\n", - " else:\n", - " features = [feature.replace(\" \", \"\\n\") for feature in features]\n", - "\n", - " roundup_corr = math.ceil(corrs.max() * 10) / 10\n", - "\n", - " # Define color scheme\n", - " plt.rcParams[\"text.color\"] = \"#1f1f1f\"\n", - " cmap = cm.get_cmap(\"YlOrRd\")\n", - " norm = plt.Normalize(vmin=corrs.min(), vmax=corrs.max())\n", - " colors = cmap(norm(corrs))\n", - "\n", - " # Plot radar\n", - " if fig is None and ax is None:\n", - " fig, ax = plt.subplots(figsize=(9, 9), subplot_kw={\"projection\": \"polar\"})\n", - " \n", - " ax.set_theta_offset(0)\n", - " ax.set_ylim(-0.1, roundup_corr)\n", - "\n", - " ax.bar(angles, corrs, color=colors, alpha=0.9, width=0.52, zorder=10) \n", - " ax.vlines(angles, 0, roundup_corr, color=\"grey\", ls=(0, (4, 4)), zorder=11)\n", - "\n", - " ax.set_xticks(angles)\n", - " ax.set_xticklabels(features, size=fontsize, zorder=13)\n", - "\n", - " ax.xaxis.grid(False)\n", - "\n", - " step = 0.10000000000000009\n", - " yticks = np.round(np.arange(0, roundup_corr + step, step), 1)\n", - " ax.set_yticklabels([])\n", - " ax.set_yticks(yticks)\n", - "\n", - " ax.spines[\"start\"].set_color(\"none\")\n", - " ax.spines[\"polar\"].set_color(\"none\")\n", - "\n", - " xticks = ax.xaxis.get_major_ticks()\n", - " [xtick.set_pad(90) for xtick in xticks]\n", - "\n", - " sep = 0.06\n", - " [\n", - " ax.text(np.pi / 2, ytick - sep, f\"{ytick}\", ha=\"center\", size=fontsize-2, color=\"grey\", zorder=12) \n", - " for ytick in yticks\n", - " ]\n", - "\n", - " if out_fig is not None:\n", - " fig.savefig(out_fig, bbox_inches=\"tight\")\n", - " plt.close()\n", - " gc.collect()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from wordcloud import WordCloud\n", - "\n", - "def plot_cloud(features_list, frequencies, corrs, model, fig=None, ax=None, out_fig=None):\n", - " frequencies_dict = {}\n", - " if model == \"lda\" or model == \"gclda\":\n", - " for features, frequency, corr in zip(features_list, frequencies, corrs):\n", - " #frequency = literal_eval(frequency_str)\n", - " for word, freq in zip(features, frequency):\n", - " if word not in frequencies_dict:\n", - " frequencies_dict[word] = freq * corr\n", - " else:\n", - " for word, corr in zip(features_list, corrs):\n", - " if word not in frequencies_dict:\n", - " frequencies_dict[word] = corr\n", - " \n", - " dpi = 100\n", - " w = 
9\n", - " h = 5\n", - " if fig is None and ax is None:\n", - " fig, ax = plt.subplots(figsize=(w, h))\n", - " \n", - " wc = WordCloud(\n", - " width=w * dpi,\n", - " height=h * dpi,\n", - " background_color=\"white\", \n", - " random_state=0, \n", - " colormap=\"YlOrRd\"\n", - " )\n", - " wc.generate_from_frequencies(frequencies=frequencies_dict)\n", - " ax.imshow(wc)\n", - " # ax.axis(\"off\")\n", - " ax.get_xaxis().set_ticks([])\n", - " ax.get_yaxis().set_ticks([])\n", - " for spine in ax.spines.values():\n", - " spine.set_visible(False)\n", - " \n", - " if out_fig is not None:\n", - " fig.savefig(out_fig, bbox_inches=\"tight\", dpi=dpi)\n", - " # plt.close()\n", - " gc.collect()" + "from gradec.fetcher import _fetch_features, _fetch_frequencies, _fetch_classification\n", + "from gradec.plot import plot_radar, plot_cloud\n", + "from gradec.utils import _decoding_filter" ] }, { @@ -139,54 +23,7 @@ "metadata": {}, "outputs": [], "source": [ - "methods = [\"Percentile\", \"KMeans\", \"KDE\"]\n", - "dset_names = [\"neurosynth\", \"neuroquery\"]\n", - "models = [\"term\", \"lda\", \"gclda\"]\n", - "\n", - "classifications_dir = op.join(\"../data/classification\")\n", - "models_dir = op.join(\"../data/models\")\n", - "\n", - "n_segmentations = 30\n", - "data_lst = []\n", - "for model, dset_name, method in itertools.product(models, dset_names, methods):\n", - " corr_dir = op.join(\"../results/decoding\", f\"{dset_name}_{model}_corr_{method}\")\n", - " \n", - " # Data for wordcloud\n", - " frequencies = (\n", - " _get_twfrequencies(dset_name, model, 3, models_dir)\n", - " if model in [\"lda\", \"gclda\"]\n", - " else None\n", - " )\n", - "\n", - " tmp_data_lst = []\n", - " for seg_sol in range(3, 33):\n", - " \n", - " corr_file = op.join(corr_dir, f\"corrs_{seg_sol:02d}.csv\")\n", - " pval_file = op.join(corr_dir, f\"pvals-FDR_{seg_sol:02d}.csv\")\n", - " corr_df = pd.read_csv(corr_file, index_col=\"feature\")\n", - " pval_df = pd.read_csv(pval_file, index_col=\"feature\")\n", - " features = corr_df.index.to_list()\n", - "\n", - " class_df = pd.read_csv(op.join(classifications_dir, f\"{model}_{dset_name}_classification.csv\"), index_col=\"FEATURE\")\n", - " if model == \"term\":\n", - " classification = [class_df.loc[[feature], \"Classification\"].values[0] if feature in class_df.index else \"Non-Specific\" for feature in features]\n", - " classification = classification * corr_df.shape[1]\n", - " else:\n", - " classification = class_df[\"Classification\"].to_list() * corr_df.shape[1]\n", - "\n", - " tmp_data_df = corr_df.melt(ignore_index=False).rename(columns={'variable': 'seg_id', \"value\": \"corr\"})\n", - " tmp_data_df['seg_id'] = tmp_data_df['seg_id'].astype(int) + 1\n", - " tmp_data_df[\"pval\"] = pval_df.melt(ignore_index=False)[\"value\"]\n", - "\n", - " tmp_data_df.insert(0, 'seg_sol', [seg_sol] * len(tmp_data_df))\n", - " tmp_data_df.insert(0, 'method', [f\"{model}_{dset_name}_{method}\"] * len(tmp_data_df))\n", - " tmp_data_df[\"classification\"] = classification\n", - "\n", - " if model in [\"lda\", \"gclda\"]:\n", - " tmp_data_df[\"frequencies\"] = frequencies * corr_df.shape[1]\n", - " \n", - " tmp_data_lst.append(tmp_data_df)\n", - " data_lst.append(tmp_data_lst)" + "data_dir = op.join(\"..\", \"data\")" ] }, { @@ -195,12 +32,12 @@ "metadata": {}, "outputs": [], "source": [ - "methods = [\"Percentile\", \"KMeans\", \"KDE\"]\n", + "methods = [\"PCT\", \"KMeans\", \"KDE\"]\n", "dset_names = [\"neurosynth\", \"neuroquery\"]\n", "models = [\"term\", \"lda\", \"gclda\"]\n", "\n", 
"label_dict = {\n", - " \"Percentile\": \"PCT\", \n", + " \"PCT\": \"PCT\", \n", " \"KMeans\": \"KMeans\", \n", " \"KDE\": \"KDE\", \n", " \"neurosynth\": \"NS\", \n", @@ -210,8 +47,10 @@ " \"gclda\": \"GCLDA\"\n", "}\n", "\n", - "for seg_sol in range(5, 6):\n", - " big_cloud_fn = op.join(\"./Fig\", \"survey\", f\"filtered-cloud_{seg_sol:02d}.png\")\n", + "for seg_sol in range(2, 33):\n", + " big_cloud_fn = op.join(\"./Fig\", \"survey\", f\"cloud_{seg_sol:02d}.png\")\n", + " print(f\"\\includegraphics[scale=1]{{cloud_{seg_sol:02d}.png}}\")\n", + "\n", " if not op.exists(big_cloud_fn):\n", " cloud_fig, cloud_axes_tpl = plt.subplots(18, seg_sol)\n", " cloud_fig.set_size_inches(1.6 * seg_sol, 15)\n", @@ -219,37 +58,69 @@ " # radar_fig, radar_axes_tpl = plt.subplots(18, seg_sol, subplot_kw={\"projection\": \"polar\"})\n", " # radar_fig.set_size_inches(1.6 * seg_sol, 15)\n", " for row_i, (model, dset_name, method) in enumerate(itertools.product(models, dset_names, methods)):\n", - " data_df = data_lst[row_i][seg_sol-3]\n", - "\n", - " for seg_id in range(1, seg_sol+1):\n", - " cloud_ax = cloud_axes_tpl[row_i, seg_id-1]\n", - " # radar_ax = radar_axes_tpl[row_i, seg_id-1]\n", - "\n", - " # data_df = data_df.rename(columns={ data_df.columns[0]: \"index\" })\n", - " # & pval < 0.05 only for las plot\n", - " # filtered_df = data_df.query(f'seg_sol == {seg_sol} & seg_id == {seg_id} & corr > 0 & classification == \"Functional\"')\n", - " filtered_df = data_df.query(f'seg_sol == {seg_sol} & seg_id == {seg_id} & corr > 0 & classification == \"Functional\" & pval < 0.05')\n", - " filtered_df = filtered_df.sort_values(by=['corr'], ascending=False)\n", - "\n", - " # Data for radar plot\n", - " corrs = filtered_df[\"corr\"].values\n", - " features = filtered_df.index.values\n", - " frequencies = filtered_df[\"frequencies\"].values if model in [\"lda\", \"gclda\"] else None\n", - "\n", - " features_split = [feature.split(\"_\")[1:] for feature in features] if model in [\"lda\", \"gclda\"] else features\n", - " \n", - " radar_fn = op.join(\"./Fig\", \"survey\", f\"filtered-radar_{model}-{dset_name}-{method}_{seg_sol:02d}-{seg_id:02d}.eps\")\n", - " cloud_fn = op.join(\"./Fig\", \"survey\", f\"filtered-cloud_{model}-{dset_name}-{method}_{seg_sol:02d}-{seg_id:02d}.png\")\n", - "\n", - " plot_radar(corrs, features, model, out_fig=radar_fn)\n", - " plot_cloud(features_split, frequencies, corrs, model, out_fig=cloud_fn)\n", + " corr_dir = op.join(\"../results/decoding\", f\"{dset_name}_{model}_corr_{method}\")\n", + " corr_file = op.join(corr_dir, f\"corrs_{seg_sol:02d}.csv\")\n", + " pval_file = op.join(corr_dir, f\"pvals-FDR_{seg_sol:02d}.csv\")\n", + " corr_df = pd.read_csv(corr_file, index_col=\"feature\")\n", + " pval_df = pd.read_csv(pval_file, index_col=\"feature\")\n", + " \n", + " # Load features for visualization\n", + " features = _fetch_features(dset_name, model, data_dir=data_dir)\n", + " classification, class_lst = _fetch_classification(dset_name, model, data_dir=data_dir)\n", + "\n", + " for seg in range(seg_sol):\n", + " seg_id = seg + 1\n", + " cloud_ax = cloud_axes_tpl[row_i, seg]\n", + "\n", + " data_df = corr_df[[f\"{seg}\"]]\n", + "\n", + " if model in [\"lda\", \"gclda\"]:\n", + " frequencies = _fetch_frequencies(dset_name, model, data_dir=data_dir)\n", + " filtered_df, filtered_features, filtered_frequencies = _decoding_filter(\n", + " data_df,\n", + " features,\n", + " classification,\n", + " freq_by_topic=frequencies,\n", + " class_by_topic=class_lst,\n", + " )\n", + " else:\n", + " 
filtered_df, filtered_features = _decoding_filter(\n", + " data_df,\n", + " features,\n", + " classification,\n", + " )\n", + " filtered_df.columns = [\"r\"]\n", + "\n", + " # Visualize results\n", + " corrs = filtered_df[\"r\"].to_numpy()\n", + "\n", + " # Word cloud plot\n", + " cloud_fn = op.join(\"./Fig\", \"survey\", f\"cloud_{model}-{dset_name}-{method}_{seg_sol:02d}-{seg_id:02d}.png\")\n", + " if model in [\"lda\", \"gclda\"]:\n", + " plot_cloud(\n", + " corrs, \n", + " filtered_features,\n", + " model,\n", + " frequencies=filtered_frequencies,\n", + " cmap=\"YlOrRd\",\n", + " ax=cloud_ax,\n", + " out_fig=cloud_fn,\n", + " )\n", + " else:\n", + " plot_cloud(\n", + " corrs, \n", + " filtered_features,\n", + " model,\n", + " cmap=\"YlOrRd\",\n", + " ax=cloud_ax,\n", + " out_fig=cloud_fn,\n", + " )\n", " \n", - " plot_cloud(features_split, frequencies, corrs, model, cloud_fig, cloud_ax)\n", " if row_i == 0:\n", " cloud_ax.set_title(f\"Segment {seg_id:02d}\", fontsize=8)\n", " # radar_ax.set_title(f\"Segment {seg_id}\", fontsize=8)\n", " if seg_id == 1:\n", - " print(label_dict[model], label_dict[dset_name], label_dict[method])\n", + " # print(label_dict[model], label_dict[dset_name], label_dict[method])\n", " # cloud_axes_tpl[0,0].spines[\"left\"].set_visible(True)\n", " cloud_ax.set_ylabel(\n", " f\"{label_dict[model]}-{label_dict[dset_name]}\\n{label_dict[method]}\", \n", @@ -262,7 +133,7 @@ " \n", " plt.tight_layout(w_pad=0.8, h_pad=0.8)\n", " # plt.subplots_adjust(wspace=0.3, hspace=0.3)\n", - " cloud_fig.savefig(big_cloud_fn, bbox_inches=\"tight\", dpi=1000)\n", + " cloud_fig.savefig(big_cloud_fn, bbox_inches=\"tight\", dpi=300)\n", " #big_radar_fn = op.join(\"./Fig\", \"survey\", f\"radar_{seg_sol}.eps\")\n", " #radar_fig.savefig(big_radar_fn, bbox_inches=\"tight\")\n", " plt.close()\n",
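Editor's notes on the patterns used above follow; none of these sketches are part of the diff.

The `hue_order` and `model_dict` literals introduced in the performance-figure hunk follow a strict `{model}_{dset}_{method}` naming scheme across 18 entries, so both can be generated rather than hand-maintained. A minimal sketch that reproduces the same lists, assuming the scheme holds:

```python
import itertools

models = ["term", "lda", "gclda"]
methods = ["PCT", "KMeans", "KDE"]
dsets = ["neurosynth", "neuroquery"]
dset_abbr = {"neurosynth": "NS", "neuroquery": "NQ"}

# Same 18-entry ordering as the literal list: model outermost, dataset innermost.
hue_order = [
    f"{model}_{dset}_{method}"
    for model, method, dset in itertools.product(models, methods, dsets)
]

# e.g. "term_neurosynth_PCT" -> "NS-TERM-PCT", matching model_dict above.
model_dict = {}
for key in hue_order:
    model, dset, method = key.split("_")
    model_dict[key] = f"{dset_abbr[dset]}-{model.upper()}-{method}"
```

Generating the mapping keeps the legend order and the display labels in sync if a segmentation method is ever added or renamed.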
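The summary loop referenced by the comment added above (the cell that builds `ic_lst`/`tfidf_lst`/`corr_lst` per segmentation method and prints pooled means) can be collapsed into a single pandas `groupby`. A sketch with an invented toy frame standing in for `data_df`; note that pandas' default `std()` uses `ddof=1`, while the cell's `corr_arr.std()` is NumPy's `ddof=0` flavor:

```python
import pandas as pd

# Toy stand-in for data_df: only the two columns the summary touches.
data_df = pd.DataFrame(
    {
        "method": ["term_neurosynth_KDE", "lda_neuroquery_KDE", "term_neurosynth_PCT"],
        "max_corr": [0.41, 0.37, 0.52],
    }
)

summary = (
    data_df.assign(segmentation=data_df["method"].str.split("_").str[-1])
    .groupby("segmentation")["max_corr"]
    .agg(["mean", lambda s: s.std(ddof=0)])  # ddof=0 matches corr_arr.std()
)
print(summary)
```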
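Fig-S7 regresses `max_corr` on `segment_size` with `logx=True`, i.e. a fit of the form y = a + b·log(x), a reasonable choice when segment sizes range from tens of vertices up to roughly 42,000. An illustration on entirely synthetic data:

```python
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Synthetic segment sizes and correlations; values are invented for the demo.
rng = np.random.default_rng(0)
sizes = rng.integers(10, 42_000, 200)
demo = pd.DataFrame({
    "segment_size": sizes,
    "max_corr": 0.05 * np.log(sizes) + rng.normal(0, 0.05, 200),
})

# logx=True fits max_corr ~ a + b*log(segment_size).
sns.regplot(data=demo, x="segment_size", y="max_corr", logx=True,
            scatter_kws={"s": 10}, line_kws={"color": "r"})
plt.show()
```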
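`bin_df` is defined elsewhere in the notebook, so what follows is a guess from its call sites (`bin_df(glue_corr_sorted, 90)`, `bin_df(glue_separation*-1, 98)`): it appears to flag cells at or above a matrix-wide percentile. A plausible, hypothetical re-implementation under that assumption:

```python
import numpy as np
import pandas as pd

def bin_df(df: pd.DataFrame, percent: float) -> pd.DataFrame:
    """Hypothetical stand-in: 1 where a cell is at/above the matrix-wide
    `percent`-th percentile, else 0 (NaNs are ignored when thresholding)."""
    threshold = np.nanpercentile(df.to_numpy(dtype=float), percent)
    return (df >= threshold).astype(int)
```

The sign flip at the separation call sites matches the commented-out `.min().min()` alternative nearby: cluster separation is best when small, so thresholding `-1 * separation` at the 98th percentile flags the smallest separations.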
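The red outlines in Fig-08a, Fig-S08a, and Fig-S10a come from placing unit `Rectangle` patches over every non-zero cell of the binarized matrix. A minimal standalone version of that pattern on synthetic scores, including the easy-to-miss detail that `np.nonzero` returns `(rows, cols)` while `Rectangle` expects an `(x, y) = (col, row)` anchor:

```python
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

rng = np.random.default_rng(1)
scores = rng.random((4, 6))
mask = (scores >= np.percentile(scores, 90)).astype(int)

ax = sns.heatmap(scores, cmap="Blues")
rows, cols = np.nonzero(mask)
for r, c in zip(rows, cols):
    # Heatmap cells are 1x1 in data coordinates; the anchor is (x=col, y=row).
    ax.add_patch(Rectangle((c, r), 1, 1, fill=False, edgecolor="red", lw=1))
plt.show()
```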
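The "Overall Performance" panels (Fig-08b, Fig-S08b, Fig-S10b) all follow one pattern: sum the per-metric binary masks cell-wise with `functools.reduce`, then render the resulting counts with a `Blues` colormap discretized into one color per integer, relabeling the colorbar so each band gets a centered tick. A self-contained sketch of both steps with toy masks; here `vmax` is set to the number of bins so that a tick at `k + 0.5` lands mid-band:

```python
from functools import reduce

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Two toy binary "winner" masks, summed cell-wise into a count matrix;
# fill_value=0 guards against non-overlapping row/column labels.
masks = [
    pd.DataFrame([[1, 0, 1], [0, 1, 0]]),
    pd.DataFrame([[1, 1, 0], [0, 1, 0]]),
]
counts = reduce(lambda x, y: x.add(y, fill_value=0), masks)

n_bins = int(counts.max().max()) + 1  # distinct count values 0..max
# Newer matplotlib spells this matplotlib.colormaps["Blues"].resampled(n_bins);
# plt.get_cmap mirrors the notebook's plt.cm.get_cmap call.
cmap = plt.get_cmap("Blues", n_bins)

fig, ax = plt.subplots(figsize=(6, 2))
sns.heatmap(counts, cmap=cmap, vmin=0, vmax=n_bins, ax=ax)

# One centered integer label per color band.
colorbar = ax.collections[0].colorbar
colorbar.set_ticks(np.arange(n_bins) + 0.5)
colorbar.set_ticklabels([str(v) for v in range(n_bins)])
plt.show()
```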
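Finally, in the `11_decoding-fig.ipynb` hunk the inline pandas query visible in the removed lines (`corr > 0 & classification == "Functional" & pval < 0.05`, followed by sorting on correlation) moves into `gradec.utils._decoding_filter`. For readers without `gradec` installed, a conceptual stand-in that mirrors just the removed logic (hypothetical name and signature, not gradec's actual implementation):

```python
import pandas as pd

def filter_functional(df: pd.DataFrame) -> pd.DataFrame:
    """Keep positively correlated, FDR-significant 'Functional' features,
    strongest first. Expects 'corr', 'pval', and 'classification' columns.
    (Hypothetical stand-in for gradec.utils._decoding_filter.)"""
    kept = df.query('corr > 0 & classification == "Functional" & pval < 0.05')
    return kept.sort_values(by="corr", ascending=False)
```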