diff --git a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py index 5233d8f..70961d8 100644 --- a/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py +++ b/src/cryo_challenge/_distribution_to_distribution/distribution_to_distribution.py @@ -13,6 +13,13 @@ ) +def sort_by_transport(cost): + m,n = cost.shape + _, transport = compute_wasserstein_between_distributions_from_weights_and_cost(np.ones(m) / m, np.ones(n)/n, cost) + indices = np.argsort((transport * np.arange(m)[...,None]).sum(0)) + return cost[:,indices], indices, transport + + def compute_wasserstein_between_distributions_from_weights_and_cost( weights_a, weights_b, cost, numItermax=1000000 ): diff --git a/tutorials/5_tutorial_plotting.ipynb b/tutorials/5_tutorial_plotting.ipynb index 1503111..ed8a924 100644 --- a/tutorials/5_tutorial_plotting.ipynb +++ b/tutorials/5_tutorial_plotting.ipynb @@ -19,11 +19,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "from cryomethods_comparison_pipeline.distribution_to_distribution import sort_by_transport\n", + "from cryo_challenge._distribution_to_distribution.distribution_to_distribution import sort_by_transport\n", "from cryo_challenge._ploting.plotting_utils import res_at_fsc_threshold\n", "\n", "from dataclasses import dataclass\n", @@ -57,50 +57,60 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "path_to_config = FileChooser(os.path.expanduser(\"~\"))\n", + "display(path_to_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9242b69072f64782b69ff2b63aa23d53", - "version_major": 2, - "version_minor": 0 - }, "text/plain": [ - "FileChooser(path='/mnt/home/gwoollard', filename='', title='', show_hidden=False, select_desc='Select', changeā€¦" + "{'prob_submitted_plot': {'pkl_fnames': ['/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_0.pkl',\n", + " '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_1.pkl',\n", + " '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_2.pkl',\n", + " '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_3.pkl',\n", + " '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_4.pkl',\n", + " '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_5.pkl',\n", + " '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_6.pkl',\n", + " '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_7.pkl',\n", + " '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_8.pkl',\n", + " '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_9.pkl',\n", + " '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_10.pkl',\n", + " '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate10_norank_submission_11.pkl']},\n", + " 'emd_plot': {'pkl_globs': ['/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate[0-9]_norank_submission_*.pkl',\n", + " '/mnt/home/gwoollard/ceph/repos/cryomethods_comparison_pipeline/results/distribution_to_distribution_20240416_npoolmicrostate[0-9][0-9]_norank_submission_*.pkl']}}" ] }, + "execution_count": 20, "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" } ], - "source": [ - "path_to_config = FileChooser(os.path.expanduser(\"~\"))\n", - "display(path_to_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], "source": [ "@dataclass_json\n", "@dataclass\n", "class PlottingConfig:\n", " gt_metadata: str\n", " map2map_results: List[str]\n", - " dist2dist_results: Dict[str, List[str]]\n", + " dist2dist_results: Dict[str, Dict[str, List[str]]]\n", "\n", - "with open(path_to_config.value, \"r\") as file:\n", + "with open(path_to_config, \"r\") as file:\n", " config = yaml.safe_load(file)\n", - "config = PlottingConfig.from_dict(config)" + "config = PlottingConfig.from_dict(config)\n", + "config.dist2dist_results" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -118,7 +128,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -136,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -196,7 +206,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -215,26 +225,27 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fscs_sorted_d['Cookie Dough'], threshold=0.5)\n", - "units_Angstroms = 2 * 2.146 / (np.arange(1,112+1) / 112)" + "n_fourier_bins = fscs_sorted_d['Cookie Dough'].shape[-1]\n", + "units_Angstroms = 2 * 2.146 / (np.arange(1,n_fourier_bins+1) / n_fourier_bins)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 10, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" }, @@ -256,7 +267,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -286,10 +297,10 @@ "\n", " for idx, (anonymous_label, fscs) in enumerate(fscs_sorted_d.items()):\n", " # map2map_dist_matrix = data.iloc[gt_ordering].values\n", - " res_fsc_half, fraction_nyquist = res_at_fsc_threshold(fscs, threshold=0.5)\n", + " res_fsc_half, _ = res_at_fsc_threshold(fscs, threshold=0.5)\n", " map2map_dist_matrix = units_Angstroms[res_fsc_half][gt_ordering]\n", "\n", - " sorted_map2map_dist_matrix, indices, transport = sort_by_transport(map2map_dist_matrix)\n", + " sorted_map2map_dist_matrix, _, _ = sort_by_transport(map2map_dist_matrix)\n", "\n", "\n", " ncols = 4\n", @@ -323,11 +334,11 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ - "fname = config.dist2dist_results['pkl_fnames'][0]\n", + "fname = config.dist2dist_results['prob_submitted_plot']['pkl_fnames'][0]\n", "\n", "with open(fname, 'rb') as f:\n", " data = pickle.load(f)" @@ -335,7 +346,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -349,12 +360,12 @@ " data_d[anonymous_label] = data\n", " return data_d\n", "\n", - "dist2dist_results_d = get_dist2dist_results(config.dist2dist_results['pkl_fnames'])" + "dist2dist_results_d = get_dist2dist_results(config.dist2dist_results['prob_submitted_plot']['pkl_fnames'])" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 34, "metadata": {}, "outputs": [ { @@ -383,7 +394,7 @@ " y=0.95)\n", " alpha = 0.05\n", "\n", - " for idx_fname, (key,data) in enumerate(dist2dist_results_d.items()):\n", + " for idx_fname, (_,data) in enumerate(dist2dist_results_d.items()):\n", " \n", "\n", " axes[idx_fname//ncols, idx_fname%ncols].plot(data['user_submitted_populations'], color='black', label='submited')\n", @@ -421,7 +432,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -461,16 +472,16 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ - "df = wragle_pkl_to_dataframe(config.dist2dist_results['pkl_globs'])" + "df = wragle_pkl_to_dataframe(config.dist2dist_results['emd_plot']['pkl_globs'])" ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -505,9 +516,9 @@ " plt.figure(figsize=(plot_width, plot_height), dpi=300)\n", "\n", " # Create a scatter plot for each id\n", - " for i, id in enumerate(ids):\n", - " df_average_id = df_average[df_average['id'] == id]\n", - " sns.scatterplot(x='EMD_submitted_norm', y='EMD_opt_norm', data=df_average_id, alpha=alpha, marker=markers[i % len(markers)], label=id, s=marker_size)\n", + " for idx, id_label in enumerate(ids):\n", + " df_average_id = df_average[df_average['id'] == id_label]\n", + " sns.scatterplot(x='EMD_submitted_norm', y='EMD_opt_norm', data=df_average_id, alpha=alpha, marker=markers[idx % len(markers)], label=id_label, s=marker_size)\n", "\n", " plt.errorbar(x=df_average_and_error['EMD_submitted_norm'], \n", " y=df_average_and_error['EMD_opt_norm'], \n",